device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp Source File

device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp Source File

Composable Kernel: device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp Source File
device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename BDataType,
30 typename DsDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t BlockSize,
39 index_t MPerBlock,
40 index_t NPerBlock,
41 index_t KPerBlock,
42 index_t AK1,
43 index_t BK1,
44 index_t MPerXDL,
45 index_t NPerXDL,
46 index_t MXdlPerWave,
47 index_t NXdlPerWave,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
54 bool ABlockLdsExtraM,
55 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
56 typename BBlockTransferThreadClusterArrangeOrder,
57 typename BBlockTransferSrcAccessOrder,
58 index_t BBlockTransferSrcVectorDim,
59 index_t BBlockTransferSrcScalarPerVector,
60 index_t BBlockTransferDstScalarPerVector_BK1,
61 bool BBlockLdsExtraN,
62 index_t CShuffleMXdlPerWavePerShuffle,
63 index_t CShuffleNXdlPerWavePerShuffle,
64 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
65 typename CDEShuffleBlockTransferScalarPerVectors,
68 typename ComputeTypeA = CDataType,
69 typename ComputeTypeB = ComputeTypeA,
70 typename LDSTypeA = ComputeTypeA,
71 typename LDSTypeB = ComputeTypeB>
74 BLayout,
75 DsLayout,
76 CLayout,
77 ADataType,
78 BDataType,
79 DsDataType,
80 CDataType,
81 AElementwiseOperation,
82 BElementwiseOperation,
83 CElementwiseOperation>
84{
// N-direction XDL repeat counts, one per GetNXdlPerWave<true>() / GetNXdlPerWave<false>()
// query. IsSupportedArgument below uses the "64" value when get_warp_size() == 64 and the
// "32" value otherwise.
// NOTE(review): GetNXdlPerWave is not defined in the visible lines; presumably it is
// injected by the GET_NXDL_PER_WAVE_IMPL macro from device_base.hpp -- confirm.
86 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
87 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
// Number of D tensors, taken from DsDataType::Size().
88 static constexpr index_t NumDTensor = DsDataType::Size();
89
90 // GridwiseGemm
// Alias template for the grid-level GEMM type. It forwards this device op's template
// parameters and substitutes NXdlPerWave_ for the N-direction XDL repeat count, so the
// same configuration can be instantiated with different NXdlPerWave values.
// NOTE(review): the line that opens the alias (listing line 92) is missing from this
// extract; the cross-reference section shows it as
// `using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle<...>`.
// NOTE(review): BlkGemmPipeSched and BlkGemmPipelineVer are used below but their
// declarations (listing lines 66-67 of the template header) are also missing here.
91 template <index_t NXdlPerWave_>
93 ALayout,
94 BLayout,
95 DsLayout,
96 CLayout,
97 ADataType,
98 BDataType,
99 GemmAccDataType,
100 CShuffleDataType,
101 DsDataType,
102 CDataType,
103 AElementwiseOperation,
104 BElementwiseOperation,
105 CElementwiseOperation,
106 GemmSpec,
107 BlockSize,
108 MPerBlock,
109 NPerBlock,
110 KPerBlock,
111 AK1,
112 BK1,
113 MPerXDL,
114 NPerXDL,
115 MXdlPerWave,
// The only argument that differs between instantiations of this alias.
116 NXdlPerWave_,
117 ABlockTransferThreadClusterLengths_AK0_M_AK1,
118 ABlockTransferThreadClusterArrangeOrder,
119 ABlockTransferSrcAccessOrder,
120 ABlockTransferSrcVectorDim,
121 ABlockTransferSrcScalarPerVector,
122 ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): the meaning of this hard-coded `false` is defined by the gridwise
// template, which is not visible here -- confirm before changing.
123 false,
124 ABlockLdsExtraM,
125 BBlockTransferThreadClusterLengths_BK0_N_BK1,
126 BBlockTransferThreadClusterArrangeOrder,
127 BBlockTransferSrcAccessOrder,
128 BBlockTransferSrcVectorDim,
129 BBlockTransferSrcScalarPerVector,
130 BBlockTransferDstScalarPerVector_BK1,
// NOTE(review): B-side counterpart of the hard-coded `false` above; same caveat.
131 false,
132 BBlockLdsExtraN,
133 CShuffleMXdlPerWavePerShuffle,
134 CShuffleNXdlPerWavePerShuffle,
135 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136 CDEShuffleBlockTransferScalarPerVectors,
137 BlkGemmPipeSched,
138 BlkGemmPipelineVer,
139 ComputeTypeA,
140 ComputeTypeB,
141 LDSTypeA,
142 LDSTypeB>;
145
146 using Argument = typename GridwiseGemm64::Argument;
147 int GetPreShuffleParameters() override { return NPerXDL; }
148
149 // Invoker
150 struct Invoker : public BaseInvoker
151 {
// Validates `arg`, derives the launch grid and the per-split K extent, then hands the
// kernel chosen by pipeline version, split-K setting (KBatch > 1) and K-loop tail number
// to the local `Run` lambda, which launches and times it.
//
// Throws std::runtime_error when GridwiseGemm::CheckValidity(arg) fails, or when
// BlkGemmPipelineVer is none of v1/v2/v3 while a main K block loop is present.
// Returns ave_time, which starts at 0 and is only assigned by a launch; because the
// branch for "no main K block loop" is compiled out with `#if 0`, that case returns 0
// without launching anything.
//
// NOTE(review): several listing lines are missing from this extract -- the kernel
// declarations ahead of each `Run(kernel)` call, the call at listing line 212 and the
// timed-launch call at listing line 223 -- so the exact kernel template arguments cannot
// be read from here.
152 template <typename GridwiseGemm>
153 float RunImp(const typename GridwiseGemm::Argument& arg,
154 const StreamConfig& stream_config = StreamConfig{})
155 {
156 if(stream_config.log_level_ > 0)
157 {
158 arg.Print();
159 }
160
161 if(!GridwiseGemm::CheckValidity(arg))
162 {
163 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
164 }
165
166 index_t gdx, gdy, gdz;
167 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
168
169 float ave_time = 0;
170
// K extent given to each split: ceil(K / (KBatch * KPerBlock)) * KPerBlock.
171 index_t k_grain = arg.KBatch * KPerBlock;
172 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
173
174 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
175
// Launch helper shared by every branch below; `kernel` is the kernel selected there.
176 const auto Run = [&](const auto& kernel) {
177 if(stream_config.flush_cache)
178 {
179
180 std::array<std::size_t, NumDTensor> DsSize;
181
// Work on a copy of the argument; the copy is what gets passed to the rotating-memory
// wrapper and to the launch call.
182 auto arg_ = arg;
183
184 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
185 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
186 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
187 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
188
// Buffer sizes in bytes: descriptor element-space size times the element size.
189 auto size_a_buffer =
190 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
191 auto size_b_buffer =
192 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
193
194 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
195 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
196
197 static_for<0, NumDTensor, 1>{}([&](auto i) {
198 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
199 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
200 });
201 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
202 DsDataType>
203 rotating_mem(arg_,
204 stream_config.rotating_count,
205 size_a_buffer,
206 size_b_buffer,
207 DsSize);
208 rotating_mem.Print();
209
// Callback handed to the launch call below as its second argument.
210 auto run_flush_cache = [&]() {
211 // flush icache
213 // rotating mem
214 rotating_mem.Next();
215 // clear c mem
// NOTE(review): arg_.M * arg_.N is multiplied before sizeof() widens the expression;
// if both are index_t (int32_t per ck.hpp) the product can overflow for very large
// outputs. The hipMemsetAsync status is only fed to hipGetErrorString and then dropped.
216 if(arg_.KBatch > 1)
217 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
218 0,
219 arg_.M * arg_.N * sizeof(CDataType),
220 stream_config.stream_id_));
221 };
222
// NOTE(review): the call these arguments belong to (listing line 223) is missing from
// this extract; presumably `ave_time = launch_and_time_kernel_with_preprocess(...)`.
224 stream_config,
225 run_flush_cache,
226 kernel,
227 dim3(gdx, gdy, gdz),
228 dim3(BlockSize),
229 0,
230 arg_);
231 }
232 else
233 {
// Same C clear as above, on the original argument (same review note applies).
234 if(arg.KBatch > 1)
235 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
236 0,
237 arg.M * arg.N * sizeof(CDataType),
238 stream_config.stream_id_));
239
240 ave_time = launch_and_time_kernel(
241 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
242 }
243 };
244
// Compile-time estimate built from tile sizes, element sizes and BlockSize; it only
// feeds the minimum_occupancy choice below (1 when the total reaches 256, else 2).
245 constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
246 4 * (1 + GridwiseGemm::NWave);
247 constexpr auto estimated_reg_b =
248 NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
249 constexpr auto estimated_reg_c =
250 MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
251 constexpr auto estimated_reg_total =
252 estimated_reg_a + estimated_reg_b + estimated_reg_c;
253
254 constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
255
256 // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
257 // has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
258 // now");
259 if(has_main_k_block_loop)
260 {
261 // Tail number always full
262 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
263 {
// Four-way selection: split-K or not, then Odd tail versus every other tail value.
264 if(arg.KBatch > 1)
265 {
266 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
267 {
269 GridwiseGemm,
270 true,
272 minimum_occupancy,
274 Run(kernel);
275 }
276 else
277 {
279 GridwiseGemm,
280 true,
282 minimum_occupancy,
284 Run(kernel);
285 }
286 }
287 else
288 {
289 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
290 {
292 GridwiseGemm,
293 true,
295 minimum_occupancy,
297 Run(kernel);
298 }
299 else
300 {
302 GridwiseGemm,
303 true,
305 minimum_occupancy,
307 Run(kernel);
308 }
309 }
310 }
311 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
312 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
313 {
// Same four-way selection as the v1 branch, for the v2/v3 pipelines.
314 if(arg.KBatch > 1)
315 {
316 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
317 {
318 const auto kernel =
320 GridwiseGemm,
321 true,
323 minimum_occupancy,
325 Run(kernel);
326 }
327 else
328 {
329 const auto kernel =
331 GridwiseGemm,
332 true,
334 minimum_occupancy,
336 Run(kernel);
337 }
338 }
339 else
340 {
341 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
342 {
343 const auto kernel =
345 GridwiseGemm,
346 true,
348 minimum_occupancy,
350 Run(kernel);
351 }
352 else
353 {
354 const auto kernel =
356 GridwiseGemm,
357 true,
359 minimum_occupancy,
361 Run(kernel);
362 }
363 }
364 }
365 else
366 {
367 throw std::runtime_error("todo: only v1 v2 and v3 support now");
368 }
369 }
// Compiled out: kernel selection for the case without a main K block loop.
370#if 0
371 else
372 {
373 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
374 {
375#if 0
376 if(arg.KBatch > 1)
377 {
378 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
379 {
381 GridwiseGemm,
382 false,
384 minimum_occupancy,
386 Run(kernel);
387 }
388 else
389 {
391 GridwiseGemm,
392 false,
394 minimum_occupancy,
396 Run(kernel);
397 }
398 }
399 else
400 {
401 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
402 {
404 GridwiseGemm,
405 false,
407 minimum_occupancy,
409 Run(kernel);
410 }
411 else
412 {
414 GridwiseGemm,
415 false,
417 minimum_occupancy,
419 Run(kernel);
420 }
421 }
422#endif
423 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
424 {
426 GridwiseGemm,
427 false,
429 minimum_occupancy,
431 Run(kernel);
432 }
433 else
434 {
436 GridwiseGemm,
437 false,
439 minimum_occupancy,
441 Run(kernel);
442 }
443 }
444 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
445 {
446 if(arg.KBatch > 1)
447 {
448 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
449 {
450 const auto kernel =
452 GridwiseGemm,
453 false,
455 minimum_occupancy,
457 Run(kernel);
458 }
459 else
460 {
461 const auto kernel =
463 GridwiseGemm,
464 false,
466 minimum_occupancy,
468 Run(kernel);
469 }
470 }
471 else
472 {
473 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
474 {
475 const auto kernel =
477 GridwiseGemm,
478 false,
480 minimum_occupancy,
482 Run(kernel);
483 }
484 else
485 {
486 const auto kernel =
488 GridwiseGemm,
489 false,
491 minimum_occupancy,
493 Run(kernel);
494 }
495 }
496 }
497 else
498 {
499 throw std::runtime_error("todo: only v3 support now");
500 }
501 }
502#endif
503
504 return ave_time;
505 }
506
508
509 // polymorphic
510 float Run(const BaseArgument* p_arg,
511 const StreamConfig& stream_config = StreamConfig{}) override
512 {
513 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
514 }
515 };
516
// Compile-time validity hook for this instance's template parameters.
// TODO: properly implement this check; for now every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter() { return true; }
522
// Host-side check of whether this instance can run `arg`. Returns false when:
//   - the (missing, see note) first condition holds;
//   - is_gfx11_supported() is true and KBatch > 1;
//   - CDataType is ck::bhalf_t, KBatch > 1 and is_bf16_atomic_supported() is false;
//   - K is not a multiple of AK1 or BK1 and GemmSpec is not one of the K-padding
//     specializations (MKPadding, NKPadding, MNKPadding, KPadding);
//   - N is not a multiple of NPerBlock, or K is not a multiple of KPerBlock;
//   - none of the wave-size branches at the end returns.
// NOTE(review): listing lines 525, 555 and 562 are missing from this extract.
// Presumably 525 is the first rejection condition and 555/562 are
// `return GridwiseGemm64/32::CheckValidity(...)` -- confirm against the real header.
523 static bool IsSupportedArgument(const Argument& arg)
524 {
526 {
527 return false;
528 }
529 if(is_gfx11_supported() && arg.KBatch > 1)
530 {
531 return false;
532 }
533 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
534 {
535 return false;
536 }
537
538 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
539 GemmSpec == GemmSpecialization::NKPadding ||
540 GemmSpec == GemmSpecialization::MNKPadding ||
541 GemmSpec == GemmSpecialization::KPadding))
542 {
543 return false;
544 }
545
546 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
547 {
548 return false;
549 }
550
// Pick the instantiation by wave size; a branch is compiled in only when its
// NXdlPerWave value is positive.
551 if(get_warp_size() == 64)
552 {
553 if constexpr(NXdlPerWave64 > 0)
554 {
556 }
557 }
558 else
559 {
560 if constexpr(NXdlPerWave32 > 0)
561 {
// `arg` is a GridwiseGemm64::Argument (see the Argument alias); it is reinterpreted
// here as the GridwiseGemm32 argument type.
// NOTE(review): this relies on both Argument types being layout-compatible -- confirm.
563 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
564 }
565 }
566 return false;
567 }
568
569 // polymorphic
570 bool IsSupportedArgument(const BaseArgument* p_arg) override
571 {
572 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
573 }
574
575 static auto MakeArgument(const void* p_a,
576 const void* p_b,
577 std::array<const void*, NumDTensor> p_ds,
578 void* p_c,
579 index_t M,
580 index_t N,
581 index_t K,
582 index_t StrideA,
583 index_t StrideB,
584 std::array<index_t, NumDTensor> StrideDs,
585 index_t StrideC,
586 index_t KBatch,
587 AElementwiseOperation a_element_op,
588 BElementwiseOperation b_element_op,
589 CElementwiseOperation c_element_op)
590 {
591 return Argument{static_cast<const ADataType*>(p_a),
592 static_cast<const BDataType*>(p_b),
593 p_ds,
594 static_cast<CDataType*>(p_c),
595 M,
596 N,
597 K,
598 StrideA,
599 StrideB,
600 StrideDs,
601 StrideC,
602 KBatch,
603 a_element_op,
604 b_element_op,
605 c_element_op};
606 }
607
608 static auto MakeInvoker() { return Invoker{}; }
609
610 // polymorphic
611 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
612 const void* p_b,
613 std::array<const void*, NumDTensor> p_ds,
614 void* p_c,
615 index_t M,
616 index_t N,
617 index_t K,
618 index_t StrideA,
619 index_t StrideB,
620 std::array<ck::index_t, NumDTensor> StrideDs,
621 index_t StrideC,
622 index_t KBatch,
623 AElementwiseOperation a_element_op,
624 BElementwiseOperation b_element_op,
625 CElementwiseOperation c_element_op) override
626 {
627 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
628 static_cast<const BDataType*>(p_b),
629 p_ds,
630 static_cast<CDataType*>(p_c),
631 M,
632 N,
633 K,
634 StrideA,
635 StrideB,
636 StrideDs,
637 StrideC,
638 KBatch,
639 a_element_op,
640 b_element_op,
641 c_element_op);
642 }
643
644 // polymorphic
645 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
646 {
647 return std::make_unique<Invoker>(Invoker{});
648 }
649
650 // polymorphic
// Builds a one-line description of this instance: GEMM specialization, the first
// character of each of the A/B/C layout names, block size, block tile, wave tile,
// wave map, vector read widths, pipeline scheduler, pipeline version and the
// PrefetchStages value of GridwiseGemm64's blockwise pipeline.
// NOTE(review): the emitted name is the literal "DeviceGemmXdlUniversal"; it is not
// derived from this class.
// NOTE(review): the initializer lines of both maps (listing lines 656-657 and 660-662)
// are missing from this extract, so the enum-to-string entries cannot be read here.
651 std::string GetTypeString() const override
652 {
653 auto str = std::stringstream();
654
655 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
658
659 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
663
664 // clang-format off
665 str << "DeviceGemmXdlUniversal"
666 << "<"
667 << getGemmSpecializationString(GemmSpec) << ", "
668 << std::string(ALayout::name)[0]
669 << std::string(BLayout::name)[0]
670 << std::string(CLayout::name)[0]
671 << ">"
672 << " BlkSize: "
673 << BlockSize << ", "
674 << "BlkTile: "
675 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
676 << "WaveTile: "
677 << MPerXDL<<"x"<<NPerXDL << ", "
678 << "WaveMap: "
679 << MXdlPerWave<<"x" << NXdlPerWave<<", "
680 << "VmemReadVec: "
681 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
682 << "BlkGemmPipelineScheduler: "
// operator[] on the maps: an enum value without an entry yields an empty string.
683 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
684 << "BlkGemmPipelineVersion: "
685 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
686 << "BlkGemmPipelinePrefetchStages: "
687 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
688 // clang-format on
689
690 return str.str();
691 }
692};
693
694} // namespace device
695} // namespace tensor_operation
696} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:39
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:82
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:168
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:151
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:510
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:153
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:84
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:651
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:144
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:611
GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:92
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:575
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:88
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:570
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:146
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:523
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:517
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:87
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:143
int GetPreShuffleParameters() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:147
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:86
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:608
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:645
Definition flush_cache.hpp:174