device_gemm_xdl_cshuffle_lds_direct_load.hpp Source File

device_gemm_xdl_cshuffle_lds_direct_load.hpp Source File#

Composable Kernel: device_gemm_xdl_cshuffle_lds_direct_load.hpp Source File
device_gemm_xdl_cshuffle_lds_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
17
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
22template <typename ALayout,
23 typename BLayout,
24 typename ELayout,
25 typename ADataType,
26 typename BDataType,
27 typename EDataType,
28 typename AccDataType,
29 typename CShuffleDataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CDEElementwiseOperation,
33 GemmSpecialization GemmSpec,
34 index_t NumGemmKPrefetchStage,
35 index_t BlockSize,
36 index_t MPerBlock,
37 index_t NPerBlock,
38 index_t KPerBlock,
39 index_t AK1,
40 index_t BK1,
41 index_t MPerXDL,
42 index_t NPerXDL,
43 index_t MXdlPerWave,
44 index_t NXdlPerWave,
45 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
46 typename ABlockTransferSrcAccessOrder,
47 index_t ABlockTransferSrcVectorDim,
48 index_t ABlockTransferScalarPerVector,
49 bool ABlockLdsExtraM,
50 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
51 typename BBlockTransferSrcAccessOrder,
52 index_t BBlockTransferSrcVectorDim,
53 index_t BBlockTransferScalarPerVector,
54 bool BBlockLdsExtraN,
55 index_t CShuffleMXdlPerWavePerShuffle,
56 index_t CShuffleNXdlPerWavePerShuffle,
57 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
58 index_t CDEBlockTransferScalarPerVector_NPerBlock,
61 typename ComputeDataType = EDataType>
63 BLayout,
64 ELayout,
65 ADataType,
66 BDataType,
67 EDataType,
68 AElementwiseOperation,
69 BElementwiseOperation,
70 CDEElementwiseOperation>
71{
73 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
74 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
75
76 static constexpr auto I1 = Number<1>{};
77
78 template <index_t NXdlPerWave_>
80 ALayout,
81 BLayout,
83 ELayout,
84 ADataType,
85 BDataType,
86 ComputeDataType,
87 AccDataType,
88 CShuffleDataType,
90 EDataType,
91 AElementwiseOperation,
92 BElementwiseOperation,
93 CDEElementwiseOperation,
95 GemmSpec,
96 NumGemmKPrefetchStage,
97 BlockSize,
98 MPerBlock,
99 NPerBlock,
100 KPerBlock,
101 AK1,
102 BK1,
103 MPerXDL,
104 NPerXDL,
105 MXdlPerWave,
106 NXdlPerWave_,
107 ABlockTransferThreadClusterLengths_AK0_M_AK1,
108 ABlockTransferSrcAccessOrder,
109 ABlockTransferSrcVectorDim,
110 ABlockTransferScalarPerVector,
111 ABlockLdsExtraM,
112 BBlockTransferThreadClusterLengths_BK0_N_BK1,
113 BBlockTransferSrcAccessOrder,
114 BBlockTransferSrcVectorDim,
115 BBlockTransferScalarPerVector,
116 BBlockLdsExtraN,
117 CShuffleMXdlPerWavePerShuffle,
118 CShuffleNXdlPerWavePerShuffle,
119 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
120 CDEBlockTransferScalarPerVector_NPerBlock,
121 LoopSched,
122 PipelineVer,
123 ComputeDataType>;
124
127
128 using Argument = typename GridwiseGemm64::Argument;
129
130 struct Invoker : public BaseInvoker
131 {
132 template <typename GridwiseGemm>
133 float RunImp(const typename GridwiseGemm::Argument& arg,
134 const StreamConfig& stream_config = StreamConfig{})
135 {
136 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
137 arg.b_grid_desc_n_k_,
138 arg.ds_grid_desc_m_n_,
139 arg.e_grid_desc_m_n_,
140 arg.block_2_etile_map_))
141 {
142 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
143 }
144
145 const index_t grid_size =
146 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
147
148 auto launch_kernel = [&](auto has_main_k_block_loop) {
149 constexpr bool has_main_loop = has_main_k_block_loop.value;
150
152 GridwiseGemm,
153 ADataType,
154 BDataType,
155 typename GridwiseGemm::DsGridPointer,
156 EDataType,
157 AElementwiseOperation,
158 BElementwiseOperation,
159 CDEElementwiseOperation,
160 typename GridwiseGemm::AGridDesc_AK0_M_AK1,
161 typename GridwiseGemm::BGridDesc_BK0_N_BK1,
162 typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
163 typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
164 typename GridwiseGemm::Block2ETileMap,
165 has_main_loop>;
166
167 return launch_and_time_kernel(stream_config,
168 kernel,
169 dim3(grid_size),
170 dim3(BlockSize),
171 0,
172 arg.p_a_grid_,
173 arg.p_b_grid_,
174 // arg.p_ds_grid_,
175 ck::Tuple<>{},
176 arg.p_e_grid_,
177 arg.a_element_op_,
178 arg.b_element_op_,
179 arg.cde_element_op_,
180 arg.a_grid_desc_ak0_m_ak1_,
181 arg.b_grid_desc_bk0_n_bk1_,
182 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
183 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
184 arg.block_2_etile_map_);
185 };
186
187 const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
188
189 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
190 {
191 return launch_kernel(integral_constant<bool, true>{});
192 }
193 else
194 {
195 return launch_kernel(integral_constant<bool, false>{});
196 }
197 }
198
200
201 float Run(const BaseArgument* p_arg,
202 const StreamConfig& stream_config = StreamConfig{}) override
203 {
204 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
205 }
206 };
207
208 static bool IsSupportedArgument(const Argument& arg)
209 {
211 {
212 return false;
213 }
215 {
216 return false;
217 }
218
220 {
221 if(!is_tf32_supported())
222 {
223 return false;
224 }
225 }
226
227 // Check vector load/store.
228 {
231
232 // Check vector load of A.
233 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
234 {
235 if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
236 {
237 return false;
238 }
239 }
240 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
241 {
242 if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
243 {
244 return false;
245 }
246 }
247 else
248 {
249 return false;
250 }
251
252 // Check vector load of B.
253 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
254 {
255 if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
256 {
257 return false;
258 }
259 }
260 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
261 {
262 if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
263 {
264 return false;
265 }
266 }
267 else
268 {
269 return false;
270 }
271
272 // Check vector load of E.
273 // For now, only the RowMajor layout is supported.
274 if constexpr(is_same_v<ELayout, Row>)
275 {
276 if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
277 {
278 return false;
279 }
280 }
281 else
282 {
283 return false;
284 }
285 }
286
287 if(get_warp_size() == 64)
288 {
289 if constexpr(NXdlPerWave64 > 0)
290 {
291 return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_,
292 arg.b_grid_desc_n_k_,
293 arg.ds_grid_desc_m_n_,
294 arg.e_grid_desc_m_n_,
295 arg.block_2_etile_map_);
296 }
297 }
298 else
299 {
300 if constexpr(NXdlPerWave32 > 0)
301 {
302 return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_,
303 arg.b_grid_desc_n_k_,
304 arg.ds_grid_desc_m_n_,
305 arg.e_grid_desc_m_n_,
306 arg.block_2_etile_map_);
307 }
308 }
309 return false;
310 }
311
312 bool IsSupportedArgument(const BaseArgument* p_arg) override
313 {
314 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
315 }
316
317 static auto MakeArgument(const void* p_a,
318 const void* p_b,
319 void* p_e,
320 index_t MRaw,
321 index_t NRaw,
322 index_t KRaw,
323 index_t StrideA,
324 index_t StrideB,
325 index_t StrideE,
326 AElementwiseOperation a_element_op,
327 BElementwiseOperation b_element_op,
328 CDEElementwiseOperation cde_element_op)
329 {
330 using EmptyDsPointers = std::array<const void*, 0>;
331 using EmptyDsStrides = std::array<ck::index_t, 0>;
332
333 return Argument{p_a,
334 p_b,
335 EmptyDsPointers{},
336 p_e,
337 MRaw,
338 NRaw,
339 KRaw,
340 StrideA,
341 StrideB,
342 EmptyDsStrides{},
343 StrideE,
344 a_element_op,
345 b_element_op,
346 cde_element_op};
347 }
348
349 static auto MakeInvoker() { return Invoker{}; }
350
351 std::unique_ptr<BaseArgument>
352 MakeArgumentPointer(const void* p_a,
353 const void* p_b,
354 void* p_e,
355 index_t MRaw,
356 index_t NRaw,
357 index_t KRaw,
358 index_t StrideA,
359 index_t StrideB,
360 index_t StrideE,
361 AElementwiseOperation a_element_op,
362 BElementwiseOperation b_element_op,
363 CDEElementwiseOperation cde_element_op) override
364 {
365 using EmptyDsPointers = std::array<const void*, 0>;
366 using EmptyDsStrides = std::array<ck::index_t, 0>;
367
368 return std::make_unique<Argument>(p_a,
369 p_b,
370 EmptyDsPointers{},
371 p_e,
372 MRaw,
373 NRaw,
374 KRaw,
375 StrideA,
376 StrideB,
377 EmptyDsStrides{},
378 StrideE,
379 a_element_op,
380 b_element_op,
381 cde_element_op);
382 }
383
384 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
385 {
386 return std::make_unique<Invoker>(Invoker{});
387 }
388
389 std::string GetTypeString() const override
390 {
391 auto str = std::stringstream();
392
393 std::map<LoopScheduler, std::string> LoopSchedToString{
394 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
395
396 std::map<PipelineVersion, std::string> PipelineVersionToString{
398
399 // clang-format off
400 str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad"
401 << "<"
402 << BlockSize << ", "
403 << MPerBlock << ", "
404 << NPerBlock << ", "
405 << KPerBlock << ", "
406 << AK1 << ", "
407 << BK1 << ", "
408 << MPerXDL << ", "
409 << NPerXDL << ", "
410 << MXdlPerWave << ", "
411 << NXdlPerWave << ", "
412 << ABlockTransferScalarPerVector << ", "
413 << BBlockTransferScalarPerVector << ", "
414 << CShuffleMXdlPerWavePerShuffle << ", "
415 << CShuffleNXdlPerWavePerShuffle << ", "
416 << getGemmSpecializationString(GemmSpec)
417 << ">"
418 << " LoopScheduler: "
419 << LoopSchedToString[LoopSched] << ", "
420 << "PipelineVersion: "
421 << PipelineVersionToString[PipelineVer] << ", "
422 << "Prefetch: "
423 << NumGemmKPrefetchStage;
424 // clang-format on
425
426 return str.str();
427 }
428};
429
430} // namespace device
431} // namespace tensor_operation
432} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
__global__ void kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:43
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:149
Definition utility/tuple.hpp:186
Definition utility/integral_constant.hpp:20
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:131
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:201
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:133
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:71
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:349
GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, ck::Tuple<>, ELayout, ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, ck::Tuple<>, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer, ComputeDataType > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:79
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:208
static constexpr auto I1
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:76
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:384
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:73
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:128
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:389
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:352
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:312
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:125
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:74
static auto MakeArgument(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:317
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_lds_direct_load.hpp:126
Definition device_gemm.hpp:22