gridwise_2d_multiple_reduction_multiblock.hpp Source File

gridwise_2d_multiple_reduction_multiblock.hpp Source File

Composable Kernel: gridwise_2d_multiple_reduction_multiblock.hpp Source File
gridwise_2d_multiple_reduction_multiblock.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck {
15
16template <typename GridwiseMultipleReduction,
17 index_t NumReduction,
18 typename InDataType,
19 typename OutDataTypePointerTuple,
20 typename AccDataType,
21 typename InGridDesc_M_K,
22 typename OutGridDesc_M_Tuple,
23 typename InElementwiseOperationTuple,
24 typename AccElementwiseOperationTuple>
25__global__ void
26kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
27 const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
28 const InElementwiseOperationTuple in_elementwise_op_tuple,
29 const AccElementwiseOperationTuple acc_elementwise_op_tuple,
30 index_t block_group_size,
31 index_t num_k_block_tile_iteration,
33 const InDataType* const __restrict__ p_in_value_global,
35 OutDataTypePointerTuple p_out_value_global_tuple)
36{
37 GridwiseMultipleReduction::Run(in_grid_desc_m_k,
38 out_grid_desc_m_tuple,
39 in_elementwise_op_tuple,
40 acc_elementwise_op_tuple,
41 block_group_size,
42 num_k_block_tile_iteration,
43 alpha_values,
44 p_in_value_global,
45 beta_values,
46 p_out_value_global_tuple);
47};
48
// Gridwise functor that computes NumReduction reductions of one (M, K) input.
// Inside Run() each input tile is loaded once and then fed to every reduction;
// each reduction has its own pre-reduction operator, post-reduction operator,
// alpha/beta factors and output (the second static_assert enforces matching sizes).
//
// NOTE(review): this listing was extracted from a rendered documentation page.
// Every code line is prefixed with its original listing number, and the lines that
// were hyperlinks in the rendering are missing (gaps in the numbering mark them).
// Listing line 68, which should carry the `struct` header, is one of the missing
// lines, so the struct's name cannot be read from this block.
49template <index_t NumReduction,
50 typename InDataType,
51 typename OutDataTypePointerTuple,
52 typename AccDataType,
53 typename InGridDesc_M_K,
54 typename OutGridDesc_M_Tuple,
55 typename ReduceOperation,
56 typename InElementwiseOperationTuple,
57 typename AccElementwiseOperationTuple,
58 InMemoryDataOperationEnum OutMemoryDataOperation,
59 bool PropagateNan,
60 index_t BlockSize,
61 index_t MThreadClusterSize,
62 index_t KThreadClusterSize,
63 index_t MThreadSliceSize,
64 index_t KThreadSliceSize,
65 index_t InSrcVectorDim,
66 index_t InSrcVectorSize,
67 typename OutDstVectorSizeSeq>
69{
    // The per-thread slice along the vectorised input dimension (0 or 1) must be a
    // whole multiple of the vector size.
70 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
71 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
72 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
73
    // Every per-reduction tuple/sequence must have exactly NumReduction entries.
74 static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
75 NumReduction == OutGridDesc_M_Tuple::Size() &&
76 NumReduction == OutDstVectorSizeSeq::Size() &&
77 NumReduction == InElementwiseOperationTuple::Size() &&
78 NumReduction == AccElementwiseOperationTuple::Size(),
79 "All tuple should have the same size as the number of Reductions!");
80
    // True when the input is vectorised along dimension 0. Per the symbol index at
    // the end of this page it selects Sequence<1, 0> over Sequence<0, 1> for
    // ThreadBufferDimAccessOrder and ThreadClusterArrangeOrder.
81 static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
82
    // NOTE(review): listing lines 83, 85-86 and 88-89 are missing here. The symbol
    // index records them as the aliases
    //   ThreadClusterLengths_M_K   = Sequence<MThreadClusterSize, KThreadClusterSize>
    //   ThreadBufferDimAccessOrder = conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type
    //   ThreadClusterArrangeOrder  = conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type
84
87
90
    // NOTE(review): the initialiser (listing line 92) is missing; presumably a
    // make_cluster_descriptor(...) call - confirm against the original header.
91 static constexpr auto thread_cluster_desc =
93
    // NOTE(review): listing lines 94-97 are missing. The symbol index records
    //   ThreadReduceSrcDesc_M_K = packed descriptor of (MThreadSliceSize, KThreadSliceSize)
    //   ThreadReduceDstDesc_M   = packed descriptor of (MThreadSliceSize)
98
    // Remains of the BlockwiseReduce alias (listing lines 99, 101, 102 missing). Per the
    // symbol index: PartitionedBlockwiseReduction<AccDataType, BlockSize,
    // ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ReduceOperation, PropagateNan>.
100 BlockSize,
103 ReduceOperation,
104 PropagateNan>;
105
    // Remains of the ThreadwiseReduce alias (listing lines 106-108 missing). Per the
    // symbol index: ThreadwiseReduction<AccDataType, ThreadReduceSrcDesc_M_K,
    // ThreadReduceDstDesc_M, ReduceOperation, PropagateNan>.
109 ReduceOperation,
110 PropagateNan>;
111
    // NOTE(review): listing line 112 is missing; the symbol index records
    // PassThroughOp = tensor_operation::element_wise::PassThrough.
113
    // Compile-time indices 0 and 1, used below to index tuples and multi-indices.
114 static constexpr auto I0 = Number<0>{};
115 static constexpr auto I1 = Number<1>{};
116
    // Extent of one block's tile along M and K: threads in the cluster times elements
    // handled per thread.
117 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
118 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
119
    // NOTE(review): listing line 120 is missing; the symbol index records
    // Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>.
121
    // Executes all NumReduction reductions for the calling thread block.
    //
    //   in_grid_desc_m_k, out_grid_desc_m_tuple   - input descriptor / per-reduction output descriptors.
    //   in_elementwise_op_tuple                   - per-reduction operator applied to each loaded element.
    //   acc_elementwise_op_tuple                  - per-reduction operator applied to the reduced value.
    //   block_group_size                          - divides the block id: the quotient selects the M tile,
    //                                               the remainder selects the K range of this block.
    //   num_k_block_tile_iteration                - number of K tiles each block accumulates.
    //   p_in_value_global                         - input pointer.
    //   p_out_value_global_tuple                  - per-reduction output pointers.
    //
    // NOTE(review): alpha_values and beta_values are used in the body, but their
    // declarations (listing lines 128 and 130) are missing from this listing. The
    // symbol index gives both as Array<AccDataType, NumReduction>.
122 __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
123 const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
124 const InElementwiseOperationTuple& in_elementwise_op_tuple,
125 const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
126 index_t block_group_size,
127 index_t num_k_block_tile_iteration,
129 const InDataType* const __restrict__ p_in_value_global,
131 OutDataTypePointerTuple p_out_value_global_tuple)
132 {
        // Starting value of every accumulator.
133 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
134
135 // LDS, reused by all reductions
136 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
137
        // Global-memory view of the input. The third argument is the reduction identity
        // expressed as InDataType.
        // NOTE(review): presumably the value substituted for invalid reads - confirm
        // against make_dynamic_buffer.
138 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
139 p_in_value_global,
140 in_grid_desc_m_k.GetElementSpaceSize(),
141 ReduceOperation::template GetIdentityValue<InDataType>());
        // One output buffer per reduction, built from that reduction's pointer and
        // descriptor (listing lines 144 and 147 are missing).
142 auto out_global_val_buf_tuple = generate_tuple(
143 [&](auto iR) {
145 p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
146 },
148
        // View of the shared work buffer declared above, BlockSize elements long.
149 auto reduce_work_buf =
150 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
151
        // Destination of the input load in the loop below.
        // NOTE(review): its type (listing line 152) is missing from this listing.
153 in_thread_buf;
154
        // Per-reduction buffers of MThreadSliceSize * KThreadSliceSize AccDataType elements;
        // they receive the output of the pre-reduction operator (listing lines 158, 163 missing).
155 auto in_thread_buf_tuple = generate_tuple(
156 [&](auto iR) {
157 (void)iR;
159 AccDataType,
160 MThreadSliceSize * KThreadSliceSize,
161 true>{};
162 },
164
        // Per-reduction accumulators.
        // NOTE(review): the buffer type (listing line 168) is missing from this listing.
165 auto accu_value_buf_tuple = generate_tuple(
166 [&](auto iR) {
167 (void)iR;
169 },
171
        // Initialise every accumulator element of every reduction to the identity value
        // (the inner loop header, listing line 173, is missing).
172 static_for<0, NumReduction, 1>{}([&](auto iR) {
174 [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
175 });
176
        // Split the block id: the quotient picks the block group (M tile), the remainder
        // picks this block's position inside the group (K range).
177 const index_t thread_local_id = get_thread_local_1d_id();
178 const index_t block_global_id = get_block_1d_id();
179 const index_t blkgroup_id = block_global_id / block_group_size;
180 const index_t block_local_id = block_global_id % block_group_size;
181
        // Map the 1-D thread id to its (m, k) position in the thread cluster.
182 const auto thread_cluster_idx =
183 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
184
185 const auto thread_m_cluster_id = thread_cluster_idx[I0];
186 const auto thread_k_cluster_id = thread_cluster_idx[I1];
187
        // Number of K elements covered by one block over all of its iterations.
188 const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
189
        // Descriptor of the per-thread slice (the argument, listing line 192, is missing).
190 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
191 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
193
        // Input loader. Its start index is
        //   M: blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
        //   K: block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize
        // (listing line 199, one template argument, is missing).
194 auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
195 AccDataType,
196 InGridDesc_M_K,
197 decltype(thread_buffer_desc),
198 ThreadBufferLengths,
200 InSrcVectorDim,
201 InSrcVectorSize,
202 1,
203 false>(
204 in_grid_desc_m_k,
205 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
206 block_local_id * reduceSizePerBlock +
207 thread_k_cluster_id * KThreadSliceSize));
208
        // After each tile the source window advances by one block tile along K only.
209 constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
210
        // do/while: the body runs once before the trip count is tested, so one tile is
        // loaded and reduced even when num_k_block_tile_iteration is 0.
211 index_t reducedTiles = 0;
212 do
213 {
            // Load this thread's slice of the current tile once; it is shared by all reductions.
214 threadwise_src_load.Run(in_grid_desc_m_k,
215 in_global_val_buf,
216 thread_buffer_desc,
217 make_tuple(I0, I0),
218 in_thread_buf);
219
            // For each reduction: apply its pre-reduction operator element by element into
            // its own buffer, then fold that buffer into its accumulators.
            // (The iM / iK loop headers, listing lines 221 and 223, are missing.)
220 static_for<0, NumReduction, 1>{}([&](auto iR) {
222 // do element-wise pre-reduction operation
224 constexpr auto offset =
225 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
226 in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
227 in_thread_buf(Number<offset>{}));
228 });
229 });
230
231 ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
232 });
233
234 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
235
236 reducedTiles++;
237 } while(reducedTiles < num_k_block_tile_iteration);
238
239 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
240
        // Per reduction: block-level reduce, post-process, optionally blend with the
        // prior output, then store.
241 static_for<0, NumReduction, 1>{}([&](auto iR) {
            // NOTE(review): listing line 243 is missing; presumably it derives OutDataType
            // (used below) from OutDataTypePointer - confirm.
242 using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
244
            // Reduce each accumulator element in place, using the shared work buffer
            // (the loop header over I, listing line 245, is missing).
246 BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf_tuple(iR)(I));
247 });
248
            // Only threads with k-cluster id 0 post-process: apply the reduction's
            // post-operator in place, then scale by its alpha.
            // NOTE(review): presumably these threads hold the block-level result after
            // BlockwiseReduce - confirm. (Loop header, listing line 249, is missing.)
250 if(thread_k_cluster_id == 0)
251 {
252 acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
253 accu_value_buf_tuple(iR)(I));
254
255 accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
256 }
257 });
258
259 if(thread_k_cluster_id == 0)
260 {
                // Non-zero beta: read the existing output values and add beta * prior to
                // the accumulators before storing.
261 if(!float_equal_zero{}(beta_values[iR]))
262 {
                    // NOTE(review): the buffer type (listing line 263) is missing.
264 priorDstValueBuf;
265
                    // Loader for the prior output (listing lines 267, 271, 272 missing).
266 auto threadwise_dst_load =
268 OutDataType,
269 decltype(out_grid_desc_m_tuple[iR]),
270 decltype(reduced_data_desc),
273 0,
274 OutDstVectorSizeSeq::At(iR),
275 1,
276 false>(
277 out_grid_desc_m_tuple[iR],
278 make_multi_index(blkgroup_id * M_BlockTileSize +
279 thread_m_cluster_id * MThreadSliceSize));
280
281 threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
282 out_global_val_buf_tuple(iR),
283 reduced_data_desc,
284 make_tuple(I0),
285 priorDstValueBuf);
286
                    // (Loop header over I, listing line 287, is missing.)
288 accu_value_buf_tuple(iR)(I) +=
289 type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
290 });
291 };
292
                // Store the result. The destination index depends only on blkgroup_id and
                // thread_m_cluster_id, so all blocks of one group target the same output elements.
                // NOTE(review): correctness then presumably relies on OutMemoryDataOperation
                // being an accumulating operation (and on how beta is used with it) - confirm
                // with the caller. (Listing lines 294, 298-300 are missing.)
293 auto threadwise_dst_store =
295 OutDataType,
296 decltype(reduced_data_desc),
297 decltype(out_grid_desc_m_tuple[iR]),
301 0,
302 OutDstVectorSizeSeq::At(iR),
303 OutMemoryDataOperation,
304 1,
305 true>(
306 out_grid_desc_m_tuple[iR],
307 make_multi_index(blkgroup_id * M_BlockTileSize +
308 thread_m_cluster_id * MThreadSliceSize),
309 PassThroughOp{});
310
311 threadwise_dst_store.Run(reduced_data_desc,
312 make_tuple(I0),
313 accu_value_buf_tuple[iR],
314 out_grid_desc_m_tuple[iR],
315 out_global_val_buf_tuple(iR));
316 };
317 });
318 };
319}; // end of the struct opened at listing line 69 (this brace does not close namespace ck)
320
321} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M_Tuple out_grid_desc_m_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple, index_t block_group_size, index_t num_k_block_tile_iteration, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition gridwise_2d_multiple_reduction_multiblock.hpp:26
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition utility/array.hpp:14
Definition gridwise_2d_multiple_reduction_multiblock.hpp:69
static constexpr index_t K_BlockTileSize
Definition gridwise_2d_multiple_reduction_multiblock.hpp:118
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_2d_multiple_reduction_multiblock.hpp:112
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_2d_multiple_reduction_multiblock.hpp:85
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_2d_multiple_reduction_multiblock.hpp:83
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M_Tuple &out_grid_desc_m_tuple, const InElementwiseOperationTuple &in_elementwise_op_tuple, const AccElementwiseOperationTuple &acc_elementwise_op_tuple, index_t block_group_size, index_t num_k_block_tile_iteration, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition gridwise_2d_multiple_reduction_multiblock.hpp:122
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_2d_multiple_reduction_multiblock.hpp:96
static constexpr auto I0
Definition gridwise_2d_multiple_reduction_multiblock.hpp:114
static constexpr auto I1
Definition gridwise_2d_multiple_reduction_multiblock.hpp:115
static constexpr index_t M_BlockTileSize
Definition gridwise_2d_multiple_reduction_multiblock.hpp:117
static constexpr auto thread_cluster_desc
Definition gridwise_2d_multiple_reduction_multiblock.hpp:91
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_2d_multiple_reduction_multiblock.hpp:94
static constexpr bool reorder_thread_cluster
Definition gridwise_2d_multiple_reduction_multiblock.hpp:81
detail::AccumulateWithNanCheck< PropagateNan, ReduceOperation, AccDataType > Accumulation
Definition gridwise_2d_multiple_reduction_multiblock.hpp:120
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ReduceOperation, PropagateNan > ThreadwiseReduce
Definition gridwise_2d_multiple_reduction_multiblock.hpp:106
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_2d_multiple_reduction_multiblock.hpp:88
PartitionedBlockwiseReduction< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ReduceOperation, PropagateNan > BlockwiseReduce
Definition gridwise_2d_multiple_reduction_multiblock.hpp:99
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_functions_accumulate.hpp:28
Definition reduction_common.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340