device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp Source File

device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp Source File#

Composable Kernel: device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp Source File
device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
17
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
22template <typename ALayout,
23 typename BLayout,
24 typename DsLayout,
25 typename ELayout,
26 typename ADataType,
27 typename BDataType,
28 typename AccDataType,
29 typename CShuffleDataType,
30 typename DsDataType,
31 typename EDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CDEElementwiseOperation,
35 GemmSpecialization GemmSpec,
36 index_t NumGemmKPrefetchStage,
37 index_t BlockSize,
38 index_t MPerBlock,
39 index_t NPerBlock,
40 index_t KPerBlock,
41 index_t AK1,
42 index_t BK1,
43 index_t MPerXDL,
44 index_t NPerXDL,
45 index_t MXdlPerWave,
46 index_t NXdlPerWave,
47 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
48 typename ABlockTransferSrcAccessOrder,
49 index_t ABlockTransferSrcVectorDim,
50 index_t ABlockTransferScalarPerVector,
51 index_t ABlockLdsExtraM,
52 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
53 typename BBlockTransferSrcAccessOrder,
54 index_t BBlockTransferSrcVectorDim,
55 index_t BBlockTransferScalarPerVector,
56 index_t BBlockLdsExtraN,
57 index_t CShuffleMXdlPerWavePerShuffle,
58 index_t CShuffleNXdlPerWavePerShuffle,
59 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CDEBlockTransferScalarPerVector_NPerBlock,
63 typename ComputeDataType = EDataType>
65 : public DeviceGemmMultipleD<ALayout,
66 BLayout,
67 DsLayout,
68 ELayout,
69 ADataType,
70 BDataType,
71 DsDataType,
72 EDataType,
73 AElementwiseOperation,
74 BElementwiseOperation,
75 CDEElementwiseOperation>
76{
78 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
79 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
80 static constexpr auto I1 = Number<1>{};
81 static constexpr index_t NumDTensor = DsDataType::Size();
82
83 template <index_t NXdlPerWave_>
85 ALayout,
86 BLayout,
87 DsLayout,
88 ELayout,
89 ADataType,
90 BDataType,
91 ComputeDataType,
92 AccDataType,
93 CShuffleDataType,
94 DsDataType,
95 EDataType,
96 AElementwiseOperation,
97 BElementwiseOperation,
98 CDEElementwiseOperation,
100 GemmSpec,
101 NumGemmKPrefetchStage,
102 BlockSize,
103 MPerBlock,
104 NPerBlock,
105 KPerBlock,
106 AK1,
107 BK1,
108 MPerXDL,
109 NPerXDL,
110 MXdlPerWave,
111 NXdlPerWave_,
112 ABlockTransferThreadClusterLengths_AK0_M_AK1,
113 ABlockTransferSrcAccessOrder,
114 ABlockTransferSrcVectorDim,
115 ABlockTransferScalarPerVector,
116 ABlockLdsExtraM,
117 BBlockTransferThreadClusterLengths_BK0_N_BK1,
118 BBlockTransferSrcAccessOrder,
119 BBlockTransferSrcVectorDim,
120 BBlockTransferScalarPerVector,
121 BBlockLdsExtraN,
122 CShuffleMXdlPerWavePerShuffle,
123 CShuffleNXdlPerWavePerShuffle,
124 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
125 CDEBlockTransferScalarPerVector_NPerBlock,
126 LoopSched,
127 PipelineVer>;
130
131 using Argument = typename GridwiseGemm64::Argument;
132
133 struct Invoker : public BaseInvoker
134 {
135
136 template <typename GridwiseGemm>
137 float RunImp(const typename GridwiseGemm::Argument& arg,
138 const StreamConfig& stream_config = StreamConfig{})
139 {
140 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
141 arg.b_grid_desc_n_k_,
142 arg.ds_grid_desc_m_n_,
143 arg.e_grid_desc_m_n_,
144 arg.block_2_etile_map_))
145 {
146 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
147 }
148
149 const index_t grid_size =
150 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
151
152 auto launch_kernel = [&](auto has_main_k_block_loop) {
153 constexpr bool has_main_loop = has_main_k_block_loop.value;
154
156 GridwiseGemm,
157 ADataType,
158 BDataType,
159 typename GridwiseGemm::DsGridPointer,
160 EDataType,
161 AElementwiseOperation,
162 BElementwiseOperation,
163 CDEElementwiseOperation,
164 typename GridwiseGemm::AGridDesc_AK0_M_AK1,
165 typename GridwiseGemm::BGridDesc_BK0_N_BK1,
166 typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
167 typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
168 typename GridwiseGemm::Block2ETileMap,
169 has_main_loop>;
170
171 return launch_and_time_kernel(stream_config,
172 kernel,
173 dim3(grid_size),
174 dim3(BlockSize),
175 0,
176 arg.p_a_grid_,
177 arg.p_b_grid_,
178 arg.p_ds_grid_,
179 arg.p_e_grid_,
180 arg.a_element_op_,
181 arg.b_element_op_,
182 arg.cde_element_op_,
183 arg.a_grid_desc_ak0_m_ak1_,
184 arg.b_grid_desc_bk0_n_bk1_,
185 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
186 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
187 arg.block_2_etile_map_);
188 };
189
190 const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
191
192 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
193 {
194 return launch_kernel(integral_constant<bool, true>{});
195 }
196 else
197 {
198 return launch_kernel(integral_constant<bool, false>{});
199 }
200 }
201
203
204 float Run(const BaseArgument* p_arg,
205 const StreamConfig& stream_config = StreamConfig{}) override
206 {
207 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
208 }
209 };
210
211 static bool IsSupportedArgument(const Argument& arg)
212 {
214 {
215 return false;
216 }
218 {
219 return false;
220 }
221
222 // Check vector load/store.
223 {
226
227 // Check vector load of A.
228 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
229 {
230 if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
231 {
232 return false;
233 }
234 }
235 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
236 {
237 if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
238 {
239 return false;
240 }
241 }
242 else
243 {
244 return false;
245 }
246
247 // Check vector load of B.
248 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
249 {
250 if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
251 {
252 return false;
253 }
254 }
255 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
256 {
257 if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
258 {
259 return false;
260 }
261 }
262 else
263 {
264 return false;
265 }
266
267 // Check vector load of Ds.
268 // For now, only the RowMajor layout is supported.
269 bool all_valid = true;
270
271 static_for<0, NumDTensor, 1>{}([&](auto i) {
272 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
273
274 if constexpr(!is_same_v<DLayout, Row>)
275 {
276 all_valid = false;
277 }
278 });
279
280 if(!all_valid)
281 {
282 return false;
283 }
284
285 // Check vector load of E.
286 // For now, only the RowMajor layout is supported.
287 if constexpr(is_same_v<ELayout, Row>)
288 {
289 if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
290 {
291 return false;
292 }
293 }
294 else
295 {
296 return false;
297 }
298 }
299
300 if(get_warp_size() == 64)
301 {
302 if constexpr(NXdlPerWave64 > 0)
303 {
304 return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_,
305 arg.b_grid_desc_n_k_,
306 arg.ds_grid_desc_m_n_,
307 arg.e_grid_desc_m_n_,
308 arg.block_2_etile_map_);
309 }
310 }
311 else
312 {
313 if constexpr(NXdlPerWave32 > 0)
314 {
315 return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_,
316 arg.b_grid_desc_n_k_,
317 arg.ds_grid_desc_m_n_,
318 arg.e_grid_desc_m_n_,
319 arg.block_2_etile_map_);
320 }
321 }
322 return false;
323 }
324
325 bool IsSupportedArgument(const BaseArgument* p_arg) override
326 {
327 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
328 }
329
330 static auto MakeArgument(const void* p_a,
331 const void* p_b,
332 std::array<const void*, NumDTensor> p_ds,
333 void* p_e,
334 index_t MRaw,
335 index_t NRaw,
336 index_t KRaw,
337 index_t StrideA,
338 index_t StrideB,
339 std::array<index_t, NumDTensor> StrideDs,
340 index_t StrideE,
341 AElementwiseOperation a_element_op,
342 BElementwiseOperation b_element_op,
343 CDEElementwiseOperation cde_element_op)
344 {
345 return Argument{p_a,
346 p_b,
347 p_ds,
348 p_e,
349 MRaw,
350 NRaw,
351 KRaw,
352 StrideA,
353 StrideB,
354 StrideDs,
355 StrideE,
356 a_element_op,
357 b_element_op,
358 cde_element_op};
359 }
360
361 static auto MakeInvoker() { return Invoker{}; }
362
363 std::unique_ptr<BaseArgument>
364 MakeArgumentPointer(const void* p_a,
365 const void* p_b,
366 std::array<const void*, NumDTensor> p_ds,
367 void* p_e,
368 index_t MRaw,
369 index_t NRaw,
370 index_t KRaw,
371 index_t StrideA,
372 index_t StrideB,
373 std::array<ck::index_t, NumDTensor> StrideDs,
374 index_t StrideE,
375 AElementwiseOperation a_element_op,
376 BElementwiseOperation b_element_op,
377 CDEElementwiseOperation cde_element_op) override
378 {
379 return std::make_unique<Argument>(p_a,
380 p_b,
381 p_ds,
382 p_e,
383 MRaw,
384 NRaw,
385 KRaw,
386 StrideA,
387 StrideB,
388 StrideDs,
389 StrideE,
390 a_element_op,
391 b_element_op,
392 cde_element_op);
393 }
394
395 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
396 {
397 return std::make_unique<Invoker>(Invoker{});
398 }
399
400 std::string GetTypeString() const override
401 {
402 auto str = std::stringstream();
403
404 std::map<LoopScheduler, std::string> LoopSchedToString{
405 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
406
407 std::map<PipelineVersion, std::string> PipelineVersionToString{
409
410 // clang-format off
411 str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad"
412 << "<"
413 << BlockSize << ", "
414 << MPerBlock << ", "
415 << NPerBlock << ", "
416 << KPerBlock << ", "
417 << AK1 << ", "
418 << BK1 << ", "
419 << MPerXDL << ", "
420 << NPerXDL << ", "
421 << MXdlPerWave << ", "
422 << NXdlPerWave << ", "
423 << ABlockTransferScalarPerVector << ", "
424 << BBlockTransferScalarPerVector << ", "
425 << CShuffleMXdlPerWavePerShuffle << ", "
426 << CShuffleNXdlPerWavePerShuffle << ", "
427 << getGemmSpecializationString(GemmSpec)
428 << ">"
429 << " LoopScheduler: "
430 << LoopSchedToString[LoopSched] << ", "
431 << "PipelineVersion: "
432 << PipelineVersionToString[PipelineVer];
433 // clang-format on
434
435 return str.str();
436 }
437};
438
439} // namespace device
440} // namespace tensor_operation
441} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
__global__ void kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:43
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:149
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:134
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:137
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:204
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:76
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:78
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:364
static constexpr auto I1
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:80
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:395
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:131
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:211
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:325
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:129
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:81
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:128
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:79
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:400
GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:84
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:361
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:330
Definition device_gemm_multiple_d.hpp:36