gridwise_elementwise_2d.hpp Source File

gridwise_elementwise_2d.hpp Source File

Composable Kernel: gridwise_elementwise_2d.hpp Source File
gridwise_elementwise_2d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
15
16namespace ck {
17
18template <typename GridwiseElementwiseFunctor,
19 typename InGridDescTuple,
20 typename OutGridDescTuple,
21 typename InDataTypePointerTuple,
22 typename OutDataTypePointerTuple,
23 typename Block2TileMap,
24 typename ElementwiseOperation>
25__global__ void
26#if CK_USE_LAUNCH_BOUNDS
28#endif
29 kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
30 const OutGridDescTuple out_grid_desc_tuple,
31 const InDataTypePointerTuple p_in_global_tuple,
32 const OutDataTypePointerTuple p_out_global_tuple,
33 const Block2TileMap block_2_tile_map,
34 const ElementwiseOperation elementwise_op)
35{
36 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
37 out_grid_desc_tuple,
38 p_in_global_tuple,
39 p_out_global_tuple,
40 block_2_tile_map,
41 elementwise_op);
42}
43
44template <typename GridwiseElementwiseFunctorA,
45 typename GridwiseElementwiseFunctorB,
46 typename InAGridDescTuple,
47 typename InBGridDescTuple,
48 typename OutAGridDescTuple,
49 typename OutBGridDescTuple,
50 typename InADataTypePointerTuple,
51 typename InBDataTypePointerTuple,
52 typename OutADataTypePointerTuple,
53 typename OutBDataTypePointerTuple,
54 typename Block2TileMapA,
55 typename Block2TileMapB,
56 typename ElementwiseOperation>
57__global__ void
58#if CK_USE_LAUNCH_BOUNDS
60#endif
61 kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
62 const InBGridDescTuple in_grid_desc_tuple_b,
63 const OutAGridDescTuple out_grid_desc_tuple_a,
64 const OutBGridDescTuple out_grid_desc_tuple_b,
65 const InADataTypePointerTuple p_in_global_tuple_a,
66 const InBDataTypePointerTuple p_in_global_tuple_b,
67 const OutADataTypePointerTuple p_out_global_tuple_a,
68 const OutBDataTypePointerTuple p_out_global_tuple_b,
69 const Block2TileMapA block_2_tile_map_a,
70 const Block2TileMapB block_2_tile_map_b,
71 const ElementwiseOperation elementwise_op,
72 const index_t a_grid_size)
73{
74 if(get_block_1d_id() < a_grid_size)
75 {
76 GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
77 out_grid_desc_tuple_a,
78 p_in_global_tuple_a,
79 p_out_global_tuple_a,
80 block_2_tile_map_a,
81 elementwise_op,
83 }
84 else
85 {
86 GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
87 out_grid_desc_tuple_b,
88 p_in_global_tuple_b,
89 p_out_global_tuple_b,
90 block_2_tile_map_b,
91 elementwise_op,
92 get_block_1d_id() - a_grid_size);
93 }
94}
95
96template <typename GridwiseElementwiseFunctorA,
97 typename GridwiseElementwiseFunctorB,
98 typename InAGridDescTuple,
99 typename InBGridDescTuple,
100 typename OutAGridDescTuple,
101 typename OutBGridDescTuple,
102 typename InADataTypePointerTuple,
103 typename InBDataTypePointerTuple,
104 typename OutADataTypePointerTuple,
105 typename OutBDataTypePointerTuple,
106 typename Block2TileMapA,
107 typename Block2TileMapB,
108 typename ElementwiseOperation,
109 index_t NumInputsA,
110 index_t NumInputsB,
111 index_t NumOutputsA,
112 index_t NumOutputsB>
113__global__ void
114#if CK_USE_LAUNCH_BOUNDS
116#endif
117 kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a,
118 const InBGridDescTuple in_grid_desc_tuple_b,
119 const OutAGridDescTuple out_grid_desc_tuple_a,
120 const OutBGridDescTuple out_grid_desc_tuple_b,
121 const InADataTypePointerTuple p_in_global_tuple_a,
122 const InBDataTypePointerTuple p_in_global_tuple_b,
123 const OutADataTypePointerTuple p_out_global_tuple_a,
124 const OutBDataTypePointerTuple p_out_global_tuple_b,
125 const Block2TileMapA block_2_tile_map_a,
126 const Block2TileMapB block_2_tile_map_b,
127 const ElementwiseOperation elementwise_op,
128 const index_t a_grid_size,
129 const index_t batch_count_a,
130 const index_t batch_count_b,
131 const std::array<index_t, NumInputsA> input_batch_strides_a,
132 const std::array<index_t, NumInputsB> input_batch_strides_b,
133 const std::array<index_t, NumOutputsA> output_batch_strides_a,
134 const std::array<index_t, NumOutputsB> output_batch_strides_b)
135{
136 static_assert(InAGridDescTuple::Size() == NumInputsA &&
137 InADataTypePointerTuple::Size() == NumInputsA);
138 static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
139 OutADataTypePointerTuple::Size() == NumOutputsA);
140 static_assert(InBGridDescTuple::Size() == NumInputsB &&
141 InBDataTypePointerTuple::Size() == NumInputsB);
142 static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
143 OutBDataTypePointerTuple::Size() == NumOutputsB);
144
145 const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id());
146
147 if(block_id < a_grid_size)
148 {
149 const index_t num_blocks_per_batch =
150 __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
151 const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);
152
153 InADataTypePointerTuple p_in_global_with_offset_tuple;
154 OutADataTypePointerTuple p_out_global_with_offset_tuple;
155
156 static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) {
157 p_in_global_with_offset_tuple(i) =
158 p_in_global_tuple_a.At(i) +
159 type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
160 });
161
162 static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) {
163 p_out_global_with_offset_tuple(i) =
164 p_out_global_tuple_a.At(i) +
165 type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
166 });
167
168 GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
169 out_grid_desc_tuple_a,
170 p_in_global_with_offset_tuple,
171 p_out_global_with_offset_tuple,
172 block_2_tile_map_a,
173 elementwise_op,
174 block_id);
175 }
176 else
177 {
178 const index_t num_blocks_per_batch =
179 __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b);
180 const index_t g_idx =
181 __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);
182
183 InBDataTypePointerTuple p_in_global_with_offset_tuple;
184 OutBDataTypePointerTuple p_out_global_with_offset_tuple;
185
186 static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
187 p_in_global_with_offset_tuple(i) =
188 p_in_global_tuple_b.At(i) +
189 type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
190 });
191
192 static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
193 p_out_global_with_offset_tuple(i) =
194 p_out_global_tuple_b.At(i) +
195 type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
196 });
197
198 GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
199 out_grid_desc_tuple_b,
200 p_in_global_with_offset_tuple,
201 p_out_global_with_offset_tuple,
202 block_2_tile_map_b,
203 elementwise_op,
204 block_id - a_grid_size);
205 }
206}
207
208template <typename GridwiseElementwiseFunctor,
209 typename InGridDescTuple,
210 typename OutGridDescTuple,
211 typename InDataTypePointerTuple,
212 typename OutDataTypePointerTuple,
213 typename Block2TileMap,
214 typename ElementwiseOperation,
215 index_t NumInputs,
216 index_t NumOutputs>
217__global__ void
218#if CK_USE_LAUNCH_BOUNDS
220#endif
221 kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple,
222 const OutGridDescTuple out_grid_desc_tuple,
223 const InDataTypePointerTuple p_in_global_tuple,
224 const OutDataTypePointerTuple p_out_global_tuple,
225 const Block2TileMap block_2_tile_map,
226 const ElementwiseOperation elementwise_op,
227 const index_t batch_count,
228 const std::array<index_t, NumInputs> input_batch_strides,
229 const std::array<index_t, NumOutputs> output_batch_strides)
230{
231 static_assert(InGridDescTuple::Size() == NumInputs &&
232 InDataTypePointerTuple::Size() == NumInputs);
233 static_assert(OutGridDescTuple::Size() == NumOutputs &&
234 OutDataTypePointerTuple::Size() == NumOutputs);
235
236 const index_t num_blocks_per_batch =
237 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
238 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
239
240 InDataTypePointerTuple p_in_global_with_offset_tuple;
241 OutDataTypePointerTuple p_out_global_with_offset_tuple;
242
243 static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
244 p_in_global_with_offset_tuple(i) =
245 p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
246 });
247
248 static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
249 p_out_global_with_offset_tuple(i) =
250 p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
251 });
252
253 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
254 out_grid_desc_tuple,
255 p_in_global_with_offset_tuple,
256 p_out_global_with_offset_tuple,
257 block_2_tile_map,
258 elementwise_op);
259}
260
// Gridwise 2-D elementwise functor: each workgroup copies one
// M0PerBlock x M1PerBlock tile from every input tensor to every output tensor.
// elementwise_op is applied along the way by a ThreadGroupTensorSliceTransfer_v4r2.
//
// NOTE(review): this listing is a documentation extract with several source
// lines missing. Examples: line 277 (presumably `struct GridwiseElementwise`),
// 294-295, 308, 312, 321, 325, 328, 332, 335, 348, 353, 355, 370, 372-373 and
// 386-389. Consult the real header before relying on the exact template
// arguments passed to the transfer.
261template <typename InGridDescTuple,
262 typename OutGridDescTuple,
263 typename InDataTypePointerTuple,
264 typename OutDataTypePointerTuple,
265 typename Block2TileMap,
266 typename ElementwiseOperation,
267 index_t BlockSize,
268 index_t M0PerBlock,
269 index_t M1PerBlock,
270 index_t M0PerThread,
271 index_t M1PerThread,
272 typename ThreadClusterArrangeOrder,
273 typename InScalarPerVectorSeq,
274 typename OutScalarPerVectorSeq,
275 index_t SrcVectorDim,
276 index_t DstVectorDim>
278{
// Number of input / output tensors, taken from the pointer-tuple sizes.
279 static constexpr index_t NumInput = InDataTypePointerTuple::Size();
280 static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
281
// All per-tensor tuples must agree with the pointer-tuple sizes.
282 static_assert(NumInput == InScalarPerVectorSeq::Size() &&
283 NumOutput == OutScalarPerVectorSeq::Size() &&
284 NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
285 "Tuple size is inconsistent with the number of in/out!");
286
287 static constexpr auto I0 = Number<0>{};
288 static constexpr auto I1 = Number<1>{};
289
// Only 2-D tiles are handled, so the vectorized dimension must be 0 or 1.
290 static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) &&
291 (DstVectorDim == I0 || DstVectorDim == I1),
292 "Vector dim must be equal to 0 or 1.");
293
295
// Copies this workgroup's tile from all inputs to all outputs.
//   block_id defaults to get_block_1d_id(). Callers that pack several
//   problems into one grid (see the dual kernels) pass a rebased id instead.
// The tile origin is block_2_tile_map.CalculateBottomIndex(block_id),
// scaled by (M0PerBlock, M1PerBlock).
296 __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
297 const OutGridDescTuple& out_grid_desc_tuple,
298 const InDataTypePointerTuple& p_in_global_tuple,
299 const OutDataTypePointerTuple& p_out_global_tuple,
300 const Block2TileMap& block_2_tile_map,
301 const ElementwiseOperation& elementwise_op,
302 const index_t block_id = get_block_1d_id())
303 {
304
// Tuples of default-constructed element values. They are only used as type
// carriers via decltype(...) in the transfer's template arguments below.
// NOTE(review): source line 308 (the `using DataType = ...` for inputs) and the
// generate_tuple size arguments (lines 312/321) are missing from this listing.
305 constexpr auto src_datas = generate_tuple(
306 [&](auto I) {
307 using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
309
310 return DataType{};
311 },
313
314 constexpr auto dst_datas = generate_tuple(
315 [&](auto I) {
316 using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
317 using DataType = remove_pointer_t<DataTypePointer>;
318
319 return DataType{};
320 },
322
// Wrap each raw pointer in a buffer sized by its descriptor's element space.
// NOTE(review): the buffer-factory call (lines 325/332) is missing here; it is
// presumably make_dynamic_buffer<AddressSpaceEnum::Global>( — confirm in the header.
323 const auto in_global_buf_tuple = generate_tuple(
324 [&](auto I) {
326 p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
327 },
329
330 auto out_global_buf_tuple = generate_tuple(
331 [&](auto I) {
333 p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
334 },
336
// 2-D tile index for this block.
337 const auto block_work_idx =
338 block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
339
// Element coordinates of the tile origin. readfirstlane makes them explicitly
// uniform across the wave.
340 const index_t m0_block_data_idx_on_grid =
341 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
342 const index_t m1_block_data_idx_on_grid =
343 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
// Every input and every output tensor starts at the same tile origin.
344 const auto input_thread_grid_offset = generate_tuple(
345 [&](auto) {
346 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
347 },
349 const auto output_thread_grid_offset = generate_tuple(
350 [&](auto) {
351 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
352 },
354
356 // If src and dst have same vector dim, then:
357 // M0 dim - for src and dst vector load/store
358 // else:
359 // M0 dim - for dst vector load
360 // M1 dim - for src vector store
// NOTE(review): the wording above appears to swap load/store. src is read
// (vector load) and dst is written (vector store). What the code below does:
// each side's access order puts its own vector dim last — Sequence<0, 1>
// when the vector dim is 1, Sequence<1, 0> when it is 0.
361 using SrcDimAccessOrder =
362 std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
363 using DstDimAccessOrder =
364 std::conditional_t<DstVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
365
// Threads per dimension in the workgroup's thread cluster.
// Assumes M0PerBlock % M0PerThread == 0 and M1PerBlock % M1PerThread == 0
// (not checked here — TODO confirm where this is validated).
366 using ThreadClusterLengths =
367 Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;
368
// Workgroup-cooperative global->global copy applying elementwise_op.
// NOTE(review): template arguments on lines 370, 372-373 and 386-389 are
// missing from this listing. The tooltip references suggest they include
// ThisThreadBlock<BlockSize>, uniform_sequence_gen_t<...> with
// InMemoryDataOperationEnum::Set, and Sequence<M0PerBlock, M1PerBlock>.
// Verify against the header.
369 auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
371 ElementwiseOperation,
374 ThreadClusterLengths,
375 ThreadClusterArrangeOrder,
376 decltype(src_datas),
377 decltype(dst_datas),
378 InGridDescTuple,
379 OutGridDescTuple,
380 SrcDimAccessOrder,
381 DstDimAccessOrder,
382 SrcVectorDim,
383 DstVectorDim,
384 InScalarPerVectorSeq,
385 OutScalarPerVectorSeq,
390 input_thread_grid_offset,
391 out_grid_desc_tuple,
392 output_thread_grid_offset,
393 elementwise_op};
// NOTE(review): the meaning of the trailing I0 argument is defined by
// ThreadGroupTensorSliceTransfer_v4r2::Run — confirm there.
394 global_to_global_transfer.Run(
395 in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
396 }
397};
398
399} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
Definition ck.hpp:268
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition gridwise_elementwise_2d.hpp:221
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
__global__ void kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size, const index_t batch_count_a, const index_t batch_count_b, const std::array< index_t, NumInputsA > input_batch_strides_a, const std::array< index_t, NumInputsB > input_batch_strides_b, const std::array< index_t, NumOutputsA > output_batch_strides_a, const std::array< index_t, NumOutputsB > output_batch_strides_b)
Definition gridwise_elementwise_2d.hpp:117
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Definition gridwise_elementwise_2d.hpp:278
static __device__ void Run(const InGridDescTuple &in_grid_desc_tuple, const OutGridDescTuple &out_grid_desc_tuple, const InDataTypePointerTuple &p_in_global_tuple, const OutDataTypePointerTuple &p_out_global_tuple, const Block2TileMap &block_2_tile_map, const ElementwiseOperation &elementwise_op, const index_t block_id=get_block_1d_id())
Definition gridwise_elementwise_2d.hpp:296
Definition utility/sequence.hpp:43
Definition thread_group.hpp:12
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r2.hpp:45
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340