gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp Source File

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
18
19#define DEBUG_LOG 0
20
21namespace ck {
22
23// Currently we do not have an elegant way to put a single-LDS-buffer pipe and a double-LDS-buffer
24// pipe in the same kernel function. Blockers:
25// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
26// operate on two distinct LDS chunks.
27// 2. An occupied __shared__ allocation is not released until the whole shader ends, i.e. AB and C
28// may not use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
// kernel_gemm_xdl_cshuffle_v3: __global__ entry point for this gridwise GEMM.
// On gfx9 / gfx11 / gfx12 targets, and only when
// GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() is true, it declares one
// __shared__ byte buffer of GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and forwards it,
// together with the A/B/Ds/C pointers, the A/B scale pointers, the argument itself and the three
// element-wise operations stored in `karg`, to GridwiseGemm::Run. On any other target the body is
// compiled out.
// NOTE(review): `TailNum`, used in the Run<> call below, is not among the template parameters
// visible in this listing (source line 33 is missing here); presumably it is declared there.
29template <typename GridwiseGemm,
30 bool HasMainKBlockLoop,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 index_t MinimumOccupancy = 1,
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
36__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
37#endif
38 // __attribute__((amdgpu_waves_per_eu(1, 1)))
39 kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
40{
41#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
42 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
43 {
44 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // LDS workspace handed to Run
45
46 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
47 karg.p_a_grid,
48 karg.p_b_grid,
49 karg.p_ds_grid,
50 karg.p_c_grid,
51 karg.p_a_scale_grid,
52 karg.p_b_scale_grid,
53 p_shared,
54 karg,
55 karg.a_element_op,
56 karg.b_element_op,
57 karg.c_element_op);
58 }
59#else
60 ignore = karg; // unsupported target: presumably only marks karg as used
61#endif // end of #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
62}
63
64template <typename ALayout,
65 typename BLayout,
66 typename DsLayout,
67 typename CLayout,
68 typename ADataType,
69 typename BDataType,
70 typename AccDataType,
71 typename CShuffleDataType,
72 typename DsDataType,
73 typename CDataType,
74 typename AElementwiseOperation,
75 typename BElementwiseOperation,
76 typename CElementwiseOperation,
78 index_t BlockSize,
79 index_t ScaleBlockM,
80 index_t ScaleBlockN,
81 index_t ScaleBlockK,
82 index_t MPerBlock,
83 index_t NPerBlock,
84 index_t KPerBlock,
85 index_t AK1Value,
86 index_t BK1Value,
87 index_t MPerXdl,
88 index_t NPerXdl,
89 index_t MXdlPerWave,
90 index_t NXdlPerWave,
91 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
92 typename ABlockTransferThreadClusterArrangeOrder,
93 typename ABlockTransferSrcAccessOrder,
94 index_t ABlockTransferSrcVectorDim,
95 index_t ABlockTransferSrcScalarPerVector,
96 index_t ABlockTransferDstScalarPerVector_AK1,
97 bool AThreadTransferSrcResetCoordinateAfterRun,
98 index_t ABlockLdsExtraM,
99 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
100 typename BBlockTransferThreadClusterArrangeOrder,
101 typename BBlockTransferSrcAccessOrder,
102 index_t BBlockTransferSrcVectorDim,
103 index_t BBlockTransferSrcScalarPerVector,
104 index_t BBlockTransferDstScalarPerVector_BK1,
105 bool BThreadTransferSrcResetCoordinateAfterRun,
106 index_t BBlockLdsExtraN,
107 index_t CShuffleMXdlPerWavePerShuffle,
108 index_t CShuffleNXdlPerWavePerShuffle,
109 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
110 typename CDEShuffleBlockTransferScalarPerVectors,
113 typename ComputeTypeA = CDataType,
114 typename ComputeTypeB = ComputeTypeA,
115 typename LDSTypeA = ADataType,
116 typename LDSTypeB = BDataType>
118{
119 using AScaleType = float;
120 using BScaleType = float;
121
122 static constexpr auto I0 = Number<0>{};
123 static constexpr auto I1 = Number<1>{};
124 static constexpr auto I2 = Number<2>{};
125 static constexpr auto I3 = Number<3>{};
126 static constexpr auto I4 = Number<4>{};
127 static constexpr auto I5 = Number<5>{};
128 static constexpr auto I6 = Number<6>{};
129 static constexpr auto I7 = Number<7>{};
130
132 CDEShuffleBlockTransferScalarPerVectors{}[I0];
133
134 // K1 should be Number<...>
135 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
136 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
137 static constexpr auto AK1Number = Number<AK1Value>{};
138 static constexpr auto BK1Number = Number<BK1Value>{};
139
140 static constexpr index_t NumDTensor = DsDataType::Size();
141
142 static constexpr auto MakeDsGridPointer()
143 {
144 return generate_tuple(
145 [&](auto i) {
146 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
147
148 return static_cast<const DDataType*>(nullptr);
149 },
151 }
152
153 using DsGridPointer = decltype(MakeDsGridPointer());
154
155 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
156 static constexpr bool is_single_rate_mfma =
158 lcm_AK1_BK1 <= 4) ||
160 ? true
161 : false;
162 static constexpr auto is_scale_mfma = false;
163 static constexpr index_t KPack =
165 MfmaSelector<ComputeTypeA,
166 MPerXdl,
167 NPerXdl,
168 ComputeTypeB,
170 is_scale_mfma>::selected_mfma.k_per_blk);
171
173
174 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
175 {
176 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
177 }
178
179 __host__ static auto CalculateMPadded(index_t M)
180 {
181 return math::integer_least_multiple(M, MPerBlock);
182 }
183
184 __host__ static auto CalculateNPadded(index_t N)
185 {
186 return math::integer_least_multiple(N, NPerBlock);
187 }
188
189 __host__ static auto CalculateKPadded(index_t K)
190 {
191 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
192 }
193
194 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
195 {
196 auto K_t = K_Batch * KPerBlock;
197 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
198 }
199
200 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
201 {
202 auto K_t = K_Batch * KPerBlock;
203 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
204 }
205
206 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
207 {
208 auto K_t = K_Batch * KPerBlock;
209 return (K + K_t - 1) / K_t * KPerBlock;
210 }
211
212 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
213 {
214 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
215 auto K_t = K_Batch * KReadVec;
216 return (K + K_t - 1) / K_t * KReadVec;
217 }
218
219 __host__ static auto CalculateMBlock(index_t M)
220 {
221 return math::integer_divide_ceil(M, MPerBlock);
222 }
223
224 __host__ static auto CalculateNBlock(index_t N)
225 {
226 return math::integer_divide_ceil(N, NPerBlock);
227 }
228
229 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
230 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
231 {
232 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
233 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
234
236 TileDesc_K0_MN_K1{},
242 }
243
244 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
245 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
246 {
247 const auto a_grid_desc_mraw_kraw = [&]() {
249 {
250 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
251 }
253 {
254 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
255 }
256 }();
257
258 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
259
260 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
261 GemmSpec == GemmSpecialization::MNKPadding)
262 {
263 // pad both M and K
264 const auto a_grid_desc_m_k =
265 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
267 make_right_pad_transform(K, KPad - K)),
270
271 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
272 a_grid_desc_m_k,
277
278 return a_grid_desc_ak0_m_ak1;
279 }
280 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
281 GemmSpec == GemmSpecialization::MNPadding)
282 {
283 // pad M, but not K
284 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
285 a_grid_desc_mraw_kraw,
287 make_right_pad_transform(M, MPad - M)),
290
291 return a_grid_desc_ak0_m_ak1;
292 }
293 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
294 GemmSpec == GemmSpecialization::NKPadding)
295 {
296 // pad K, but not M
297 const auto a_grid_desc_m_k = transform_tensor_descriptor(
298 a_grid_desc_mraw_kraw,
302
303 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
304 a_grid_desc_m_k,
309
310 return a_grid_desc_ak0_m_ak1;
311 }
312 else
313 {
314 // not pad M or K
315 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
316 a_grid_desc_mraw_kraw,
321
322 return a_grid_desc_ak0_m_ak1;
323 }
324 }
325
326 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
327 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
328 {
329 const auto b_grid_desc_nraw_kraw = [&]() {
331 {
332 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
333 }
335 {
336 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
337 }
338 }();
339
340 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
341
342 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
343 GemmSpec == GemmSpecialization::MNKPadding)
344 {
345 // pad both N and K
346 const auto b_grid_desc_n_k =
347 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
349 make_right_pad_transform(K, KPad - K)),
352
353 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
354 b_grid_desc_n_k,
359
360 return b_grid_desc_bk0_n_bk1;
361 }
362 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
363 GemmSpec == GemmSpecialization::MNPadding)
364 {
365 // pad N, but not K
366 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
367 b_grid_desc_nraw_kraw,
369 make_right_pad_transform(N, NPad - N)),
372
373 return b_grid_desc_bk0_n_bk1;
374 }
375 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
376 GemmSpec == GemmSpecialization::MKPadding)
377 {
378 // pad K, but not N
379 const auto b_grid_desc_n_k = transform_tensor_descriptor(
380 b_grid_desc_nraw_kraw,
384
385 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
386 b_grid_desc_n_k,
391
392 return b_grid_desc_bk0_n_bk1;
393 }
394 else
395 {
396 // not pad N or K
397 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
398 b_grid_desc_nraw_kraw,
403
404 return b_grid_desc_bk0_n_bk1;
405 }
406 }
407
408 __host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
409 {
410 const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
411 const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
413 {
415 }
417 {
419 }
420 }
421
422 __host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
423 {
424 const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
425 const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
427 {
429 }
431 {
433 }
434 }
435
436 template <typename ABlockDesc_AK0_M_AK1>
437 __host__ __device__ static constexpr auto
438 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
439 {
440 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
441
442 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
443 }
444
445 template <typename BBlockDesc_BK0_N_BK1>
446 __host__ __device__ static constexpr auto
447 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
448 {
449 constexpr index_t NWaves =
450 NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl);
451
452 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
453 }
454
455 template <typename ELayout>
456 __host__ __device__ static auto
458 {
459 const auto c_grid_desc_mraw_nraw = [&]() {
461 {
462 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
463 }
465 {
466 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
467 }
468 }();
469
470 // pad M and N
471 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
473 make_right_pad_transform(N, NPad - N)),
476#if 0
477 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
478
479 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
480 GemmSpec == GemmSpecialization::MNKPadding)
481 {
482 // pad M and N
483 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
485 make_right_pad_transform(N, NPad - N)),
488 }
489 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
490 GemmSpec == GemmSpecialization::MKPadding)
491 {
492 // pad M, but not N
494 c_grid_desc_mraw_nraw,
498 }
499 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
500 GemmSpec == GemmSpecialization::NKPadding)
501 {
502 // pad N, but not M
504 c_grid_desc_mraw_nraw,
508 }
509 else
510 {
511 // not pad M or N
512 return c_grid_desc_mraw_nraw;
513 }
514#endif
515 }
516
517 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
518 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
519 {
520 return generate_tuple(
521 [&](auto i) {
522 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
523 return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
524 },
526 }
527
528 template <typename DsGridDesc>
530 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
531 {
532 return generate_tuple(
533 [&](auto i) {
535 ds_grid_desc_m_n[i], MBlock, NBlock);
536 },
538 }
539
540 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
541
542 struct Problem
543 {
544 __host__ Problem(index_t M_,
545 index_t N_,
546 index_t K_,
547 index_t StrideA_,
548 index_t StrideB_,
549 std::array<index_t, NumDTensor> StrideDs_,
550 index_t StrideC_,
551 index_t KBatch_)
552 : M{M_},
553 N{N_},
554 K{K_},
555 StrideA{StrideA_},
556 StrideB{StrideB_},
557 StrideDs{StrideDs_},
558 StrideC{StrideC_},
559 KBatch{KBatch_},
562 KRead{CalculateKRead(K_, KBatch_)},
563 KPadded{CalculateKPadded(K_, KBatch_)},
564 AK0{CalculateAK0Padded(K_, KBatch_)},
565 BK0{CalculateBK0Padded(K_, KBatch_)},
568 {
569 }
570
571 __host__ void Print() const
572 {
573 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
574 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
575 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
576 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
577 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
578 << "NBlock: " << NBlock << "}" << std::endl;
579 }
580
586 std::array<index_t, NumDTensor> StrideDs;
588
598 };
599
600 // Argument
602 {
603 __host__ Argument(const ADataType* p_a_grid_,
604 const BDataType* p_b_grid_,
605 std::array<const void*, NumDTensor> p_ds_grid_,
606 CDataType* p_c_grid_,
607 index_t M_,
608 index_t N_,
609 index_t K_,
610 index_t StrideA_,
611 index_t StrideB_,
612 std::array<index_t, NumDTensor> StrideDs_,
613 index_t StrideC_,
614 const AScaleType* p_a_scale_grid_,
615 const BScaleType* p_b_scale_grid_,
616 index_t k_batch_,
617 AElementwiseOperation a_element_op_,
618 BElementwiseOperation b_element_op_,
619 CElementwiseOperation c_element_op_)
620 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
621 p_a_grid{p_a_grid_},
622 p_b_grid{p_b_grid_},
623 p_ds_grid{},
624 p_c_grid{p_c_grid_},
625 p_a_scale_grid{p_a_scale_grid_},
626 p_b_scale_grid{p_b_scale_grid_},
627 a_element_op{a_element_op_},
628 b_element_op{b_element_op_},
629 c_element_op{c_element_op_}
630 {
631
632 // populate pointer, desc for Ds
633 static_for<0, NumDTensor, 1>{}([&](auto i) {
634 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
635
636 // D pointer
637 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
638 });
639 }
640
641 const ADataType* p_a_grid;
642 const BDataType* p_b_grid;
644 CDataType* p_c_grid;
645
648
649 const AElementwiseOperation a_element_op;
650 const BElementwiseOperation b_element_op;
651 const CElementwiseOperation c_element_op;
652 };
653
655 {
656 __device__ SplitKBatchOffset(Argument& karg)
657 {
659 {
660 a_k_split_offset = blockIdx.z * karg.KRead;
661 }
663 {
664 a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
665 }
666
668 {
669 b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
670 }
672 {
673 b_k_split_offset = blockIdx.z * karg.KRead;
674 }
675
676 if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
677 {
678 karg.K = karg.KRead;
679 }
680 else
681 {
682 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
683 }
684 }
685
688 };
689
690 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
691 {
692 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
693 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
694 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
695 // A matrix in LDS memory, dst of blockwise copy
696 if constexpr(ABlockLdsExtraM)
697 {
701 }
702 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
703 // in some cases.
705 {
706 constexpr auto a_lds_block_desc =
709
710 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
711 a_lds_block_desc,
717
718 return a_lds_block_desc_permuted;
719 }
720 else // ColumnMajor A
721 {
722 // kfold and mpair dimension is not always required.
723 // more dimension in merge_transform increase the difficulty of generating immarg offset
724 // for compiler.
725 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
726 constexpr auto M1 = MPerBlock / M0;
727
728 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
729 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
730 constexpr auto KThreadRead = WaveSize / MPerXdl;
731 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
732
733 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
734 ? 1
735 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
736 constexpr auto KThreadReadPerm =
737 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
738 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
739 : KThreadRead;
740
741 // 1<=mpair<=n0
742 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
743 ? 1
744 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
745 ? M0
746 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
747
748 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
752 Number<kfold * M0 / mpair>{},
754 AK1Number));
755
756 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
757 a_lds_block_desc,
762 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
769
770 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
771 a_lds_block_desc_permuted,
780 Sequence<1>{},
781 Sequence<2>{},
782 Sequence<3>{},
783 Sequence<4>{},
784 Sequence<5>{}),
786 Sequence<2>{},
789 Sequence<6>{},
790 Sequence<7>{}));
791
792 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
793 a_lds_block_desc_unmerged,
796 Number<KThreadWrite / kfold / KThreadReadPerm>{},
804
805 return a_lds_block_desc_ak0_m_ak1;
806 }
807 }
808
809 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
810 {
811 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
812 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
813 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
814 // B matrix in LDS memory, dst of blockwise copy
815 if constexpr(BBlockLdsExtraN)
816 {
820 }
822 {
823 constexpr auto b_lds_block_desc =
826
827 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
828 b_lds_block_desc,
834
835 return b_lds_block_desc_permuted;
836 }
837 else // RowMajor B
838 {
839 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
840 constexpr auto N1 = NPerBlock / N0;
841
842 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
843 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
844 constexpr auto KThreadRead = WaveSize / NPerXdl;
845 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
846
847 constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
848 ? 1
849 : 128 / (BK1Number * N0 * sizeof(LDSTypeB));
850 constexpr auto KThreadReadPerm =
851 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
852 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
853 : KThreadRead;
854
855 // 1<=npair<=n0
856 constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
857 ? 1
858 : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
859 ? N0
860 : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));
861
862 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
866 Number<kfold * N0 / npair>{},
868 BK1Number));
869
870 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
871 b_lds_block_desc,
876 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
883
884 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
885 b_lds_block_desc_permuted,
894 Sequence<1>{},
895 Sequence<2>{},
896 Sequence<3>{},
897 Sequence<4>{},
898 Sequence<5>{}),
900 Sequence<2>{},
903 Sequence<6>{},
904 Sequence<7>{}));
905
906 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
907 b_lds_block_desc_unmerged,
910 Number<KThreadWrite / kfold / KThreadReadPerm>{},
918
919 return b_lds_block_desc_bk0_n_bk1;
920 }
921 }
922
924 {
925 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
926 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
927
928 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
932 I1,
934
935 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
936 }
937
940 BlkGemmPipelineVer,
941 BlkGemmPipeSched,
942 BlockSize,
943 LDSTypeA,
944 LDSTypeB,
945 ComputeTypeA,
946 AccDataType,
953 ABlockTransferSrcScalarPerVector,
954 BBlockTransferSrcScalarPerVector,
955 MPerBlock,
956 NPerBlock,
957 KPerBlock,
958 MPerXdl,
959 NPerXdl,
960 MXdlPerWave,
961 NXdlPerWave,
962 KPack>())>;
963
964 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
965 {
966 // LDS allocation for A and B: be careful of alignment
967 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
968 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
969
970 // lds max alignment
971 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
972
973 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
974 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
975
976 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
977 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
978
979 // LDS allocation for C shuffle in LDS
980 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
982
983 constexpr auto c_block_size =
984 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
985
986 return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
987 b_block_space_size_aligned * sizeof(LDSTypeB)),
988 c_block_size * sizeof(CShuffleDataType));
989 }
990
992
993 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
994 __host__ static constexpr bool CheckValidity(const Argument& karg)
995 {
996 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
997 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
998 "Invalid tuning param!");
999
1000 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1005 {
1006 if(!(karg.M % MPerBlock == 0))
1007 {
1008#if DEBUG_LOG
1009 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1010 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1011 << std::endl;
1012
1013#endif // DEBUG_LOG
1014 return false;
1015 }
1016 }
1017
1018 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1023 {
1024 if(!(karg.N % NPerBlock == 0))
1025 {
1026#if DEBUG_LOG
1027 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1028 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1029 << std::endl;
1030
1031#endif // DEBUG_LOG
1032 return false;
1033 }
1034 }
1035
1036 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1040 {
1041
1042 auto K_t = karg.KBatch * KPerBlock;
1043 if(!(karg.K % K_t == 0))
1044 {
1045#if DEBUG_LOG
1046 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1047 << karg.K << " " << __FILE__ << ":" << __LINE__
1048 << ", in function: " << __func__ << std::endl;
1049
1050#endif // DEBUG_LOG
1051 return false;
1052 }
1053 }
1054 else
1055 {
1056 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1057 auto K_t = karg.KBatch * KReadVec;
1058 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1059 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1060 {
1061 return false;
1062 }
1063 }
1064
1066 {
1067 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1068 {
1069#if DEBUG_LOG
1070 std::cout << "Arg K (" << karg.K
1071 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1072 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1073 << __LINE__ << ", in function: " << __func__ << std::endl;
1074
1075#endif // DEBUG_LOG
1076 return false;
1077 }
1078 }
1079 else
1080 {
1081 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1082 {
1083#if DEBUG_LOG
1084 std::cout << "Arg M (" << karg.M
1085 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1086 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1087 << __LINE__ << ", in function: " << __func__ << std::endl;
1088
1089#endif // DEBUG_LOG
1090 return false;
1091 }
1092 }
1093
1095 {
1096 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1097 {
1098#if DEBUG_LOG
1099 std::cout << "Arg N (" << karg.N
1100 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1101 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1102 << __LINE__ << ", in function: " << __func__ << std::endl;
1103
1104#endif // DEBUG_LOG
1105 return false;
1106 }
1107 }
1108 else
1109 {
1110 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1111 {
1112#if DEBUG_LOG
1113 std::cout << "Arg K (" << karg.K
1114 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1115 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1116 << __LINE__ << ", in function: " << __func__ << std::endl;
1117
1118#endif // DEBUG_LOG
1119 return false;
1120 }
1121 }
1122
1124 {
1126 {
1127#if DEBUG_LOG
1128 std::cout << "Arg N (" << karg.N
1129 << ") value is not a multiple of "
1130 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1131 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1132 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1133
1134#endif // DEBUG_LOG
1135 return false;
1136 }
1137 }
1138 else
1139 {
1141 {
1142#if DEBUG_LOG
1143 std::cout << "Arg M (" << karg.M
1144 << ") value is not a multiple of "
1145 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1146 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1147 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1148
1149#endif // DEBUG_LOG
1150 return false;
1151 }
1152 }
1153
1154 // check gridwise gemm pipeline
1155 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1156
1157 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1158 {
1159 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1160 {
1161 return false;
1162 }
1163 }
1164
1165 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1166 return true;
1167 }
1168
1169 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1170 {
1171 const index_t num_loop = K / KPerBlock;
1172
1173 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1174 }
1175
1176 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1177 {
1178 const index_t num_loop = K / KPerBlock;
1179
1180 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1181 }
1182
1183 template <typename CGridDesc>
1185 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1186 {
1187 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1188 c_grid_desc_m_n,
1193
1194 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1195 }
1196
1197 // return block_id to C matrix tile idx (m0, n0) mapping
1198 // if arch = gfx942
1200 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1201
1202 template <bool HasMainKBlockLoop,
1203 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1204 TailNumber TailNum = TailNumber::Odd>
1205 __device__ static void Run(const ADataType* p_a_grid,
1206 const BDataType* p_b_grid,
1207 DsGridPointer& p_ds_grid,
1208 CDataType* p_c_grid,
1209 const AScaleType* p_a_scale_grid,
1210 const BScaleType* p_b_scale_grid,
1211 void* p_shared,
1212 const Problem& problem,
1213 AElementwiseOperation a_element_op,
1214 BElementwiseOperation b_element_op,
1215 CElementwiseOperation c_element_op)
1216 {
1217 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1218 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1219 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1220 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1221 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1222 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1223
1224 const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
1225 const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
1226
1227 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1229 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1230
1231 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1232 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1233 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1234 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1236 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1237
1238 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1239 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1240
1241 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1242 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1243
1244 // divide block work by [M, N]
1245 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1246
1247 const auto block_work_idx =
1248 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1249
1250 if(!block_2_ctile_map.ValidCTileIndex(
1251 block_work_idx,
1252 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1253 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1254 {
1255 return;
1256 }
1257
1258 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1259 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1260
1261 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1262 const index_t m_block_data_idx_on_grid =
1263 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1264
1265 const index_t n_block_data_idx_on_grid =
1266 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1267
1268 // lds max alignment
1269 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1270
1271 // A matrix in LDS memory, dst of blockwise copy
1272 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1273
1274 // B matrix in LDS memory, dst of blockwise copy
1275 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1276
1277 // A matrix blockwise copy
1278 auto a_blockwise_copy =
1280 AElementwiseOperation,
1284 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1285 ABlockTransferThreadClusterArrangeOrder,
1286 ADataType,
1287 LDSTypeA,
1288 decltype(a_grid_desc_ak0_m_ak1),
1289 decltype(a_block_desc_ak0_m_ak1),
1290 ABlockTransferSrcAccessOrder,
1292 ABlockTransferSrcVectorDim,
1293 2,
1294 ABlockTransferSrcScalarPerVector,
1295 ABlockTransferDstScalarPerVector_AK1,
1296 1,
1297 1,
1298 AThreadTransferSrcResetCoordinateAfterRun,
1299 true,
1300 BlockwiseGemmPipe::GlobalBufferNum>(
1301 a_grid_desc_ak0_m_ak1,
1302 make_multi_index(0, m_block_data_idx_on_grid, 0),
1303 a_element_op,
1304 a_block_desc_ak0_m_ak1,
1305 make_multi_index(0, 0, 0),
1307
1308 // B matrix blockwise copy
1309 auto b_blockwise_copy =
1311 BElementwiseOperation,
1315 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1316 BBlockTransferThreadClusterArrangeOrder,
1317 BDataType,
1318 LDSTypeB,
1319 decltype(b_grid_desc_bk0_n_bk1),
1320 decltype(b_block_desc_bk0_n_bk1),
1321 BBlockTransferSrcAccessOrder,
1323 BBlockTransferSrcVectorDim,
1324 2,
1325 BBlockTransferSrcScalarPerVector,
1326 BBlockTransferDstScalarPerVector_BK1,
1327 1,
1328 1,
1329 BThreadTransferSrcResetCoordinateAfterRun,
1330 true,
1331 BlockwiseGemmPipe::GlobalBufferNum>(
1332 b_grid_desc_bk0_n_bk1,
1333 make_multi_index(0, n_block_data_idx_on_grid, 0),
1334 b_element_op,
1335 b_block_desc_bk0_n_bk1,
1336 make_multi_index(0, 0, 0),
1338
1339 // LDS allocation for A and B: be careful of alignment
1340 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1341 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1342
1343 // Cast after lds
1345 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1346
1348 static_cast<LDSTypeB*>(p_shared) +
1349 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
1350 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1351
1352 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1353 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1354
1355 // Blockwise GEMM pipeline
1356 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1357 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1358 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1359
1360 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1361 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1362 KPerBlock);
1363
1364 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1365 constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1366 constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1367
1368 // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1369 // ScaleSliceSizeK is first dimension in C scale for packed math
1370 constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1372
1373 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1374 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1375 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1376 auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1377 (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1378
1379 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1381
1382 constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1384
1385 auto a_scale_thread_copy =
1387 AScaleType,
1388 decltype(a_scale_grid_desc_am_ak),
1389 decltype(a_scale_thread_desc),
1392 1,
1393 ScaleSliceSizeK,
1394 1,
1395 false>(
1396 a_scale_grid_desc_am_ak,
1397 make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
1398
1399 auto b_scale_thread_copy =
1401 BScaleType,
1402 decltype(b_scale_grid_desc_bn_ak),
1403 decltype(b_scale_thread_desc),
1406 1,
1407 ScaleSliceSizeK,
1408 1,
1409 false>(
1410 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1411
1412 // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1413 constexpr auto a_scale_thread_slice_copy_step =
1414 make_tuple(make_multi_index(MWaves * MPerXdl, 0),
1415 make_multi_index(-MPerBlock, 0),
1416 make_multi_index(-MPerBlock, ScaleSliceSizeK));
1417 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1418
1419 constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1420
1421 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1422 a_grid_desc_ak0_m_ak1,
1423 a_block_desc_ak0_m_ak1,
1424 a_blockwise_copy,
1425 a_grid_buf,
1426 a_block_buf,
1427 a_block_slice_copy_step,
1428 b_grid_desc_bk0_n_bk1,
1429 b_block_desc_bk0_n_bk1,
1430 b_blockwise_copy,
1431 b_grid_buf,
1432 b_block_buf,
1433 b_block_slice_copy_step,
1434
1435 c_scale_thread_desc,
1436 c_thread_buf,
1437
1438 a_scale_grid_desc_am_ak,
1439 a_scale_thread_desc,
1440 a_scale_thread_copy,
1441 a_scale_grid_buf,
1442 a_scale_thread_slice_copy_step,
1443
1444 b_scale_grid_desc_bn_ak,
1445 b_scale_thread_desc,
1446 b_scale_thread_copy,
1447 b_scale_grid_buf,
1448 b_scale_thread_slice_copy_step,
1449
1450 num_k_block_main_loop);
1451
1452 // shuffle C and write out
1453 {
1454 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1455 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1456 "wrong!");
1457
1458 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1459 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1460
1461 // transposed XDL
1462 // // TODO: hacky, fix it!
1463 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1464 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1465
1466 // // TODO: hacky, fix it!
1467 // only used to get lengths
1468 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1469 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1470
1471 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1472 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1473 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1474 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1475 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1476 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1477 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1478 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1479
1480 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1482
1483 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1484 static_cast<CShuffleDataType*>(p_shared),
1485 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1486
1487 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1488 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1489 make_tuple(
1492 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1493 M1, // M1 = MWave
1494 M2)), // M2 = MPerXdl
1497 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1498 N1, // N1 = NWave
1499 N2, // N2 * N3 * N4 = NPerXdl
1500 N3,
1501 N4))),
1503 make_tuple(
1505
1506 // calculate origin of thread output tensor on global memory
1507 // blockwise GEMM c matrix starting index
1508 const auto c_thread_mtx_on_block =
1509 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1510
1511 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1512 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1513
1514 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1519
1520 const auto m_thread_data_on_block_idx =
1521 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1522 make_multi_index(m_thread_data_on_block));
1523
1524 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1526 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1529
1530 const auto n_thread_data_on_block_idx =
1531 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1532 make_multi_index(n_thread_data_on_block));
1533
1534 // shuffle: threadwise copy C from VGPR to LDS
1535 auto c_thread_copy_vgpr_to_lds =
1537 CShuffleDataType,
1538 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1539 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1541 Sequence<CShuffleMXdlPerWavePerShuffle,
1542 CShuffleNXdlPerWavePerShuffle,
1543 I1,
1544 I1,
1545 I1,
1546 N2,
1547 I1,
1548 N4>,
1550 7,
1551 1,
1553 1,
1554 true>{
1555 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1557 0,
1558 m_thread_data_on_block_idx[I1],
1559 n_thread_data_on_block_idx[I1],
1560 m_thread_data_on_block_idx[I2],
1561 n_thread_data_on_block_idx[I2],
1562 n_thread_data_on_block_idx[I3],
1563 n_thread_data_on_block_idx[I4]),
1565
1566 using EDataType = CDataType;
1567
1568 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1569 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1570
1571 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1573 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1574
1575 const auto ds_grid_buf = generate_tuple(
1576 [&](auto i) {
1578 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1579 },
1581
1582 // tuple of reference to C/Ds tensor descriptors
1583 const auto c_ds_desc_refs = concat_tuple_of_reference(
1584 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1585 generate_tie([&](auto i) -> const auto& // return type should be reference
1586 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1588
1589 // tuple of reference to C/Ds tensor descriptors
1590 const auto c_ds_buf_refs = concat_tuple_of_reference(
1591 tie(c_shuffle_block_buf),
1592 generate_tie([&](auto i) -> const auto& // return type should be reference
1593 { return ds_grid_buf[i]; },
1595
1596 // tuple of starting index of C/Ds blockwise copy
1597 const auto idx_c_ds_block_begin = container_concat(
1598 make_tuple(make_multi_index(0, 0, 0, 0)),
1600 [&](auto) {
1601 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1602 },
1604
1605 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1606 c_grid_desc_mblock_mperblock_nblock_nperblock;
1607
1608 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1609 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1610 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1611
1612 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
1614 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1616 decltype(c_ds_desc_refs),
1617 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1618 CElementwiseOperation,
1619 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1620 // support arbitray type
1621 Sequence<1,
1622 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1623 1,
1624 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1625 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1626 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1627 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1628 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1629 3, // index_t SrcVectorDim,
1630 3, // index_t DstVectorDim,
1631 CDEShuffleBlockTransferScalarPerVectors,
1636 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1637 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
1638 {c_ds_desc_refs,
1639 idx_c_ds_block_begin,
1640 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1641 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1642 c_element_op};
1643
1644 constexpr auto sfc_c_vgpr =
1647 Sequence<CShuffleMXdlPerWavePerShuffle,
1648 CShuffleNXdlPerWavePerShuffle,
1649 1,
1650 1,
1651 1,
1652 N2,
1653 1,
1654 N4>>{};
1655
1656 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1657
1658 // space filling curve for shuffled blockwise C/D/E
1659 constexpr auto sfc_cde_block =
1662 Sequence<1,
1663 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1664 1,
1665 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1666
1667 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1668
1669 static_for<0, num_access, 1>{}([&](auto access_id) {
1670 // make sure it's safe to write to LDS
1672
1673 // each thread write its data from VGPR to LDS
1674 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1675 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1676 c_thread_buf,
1677 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1678 c_shuffle_block_buf);
1679
1680 // make sure it's safe to read from LDS
1682
1683 // each block copy its data from LDS to global
1684 cde_block_copy_lds_and_global.Run(
1685 c_ds_desc_refs,
1686 c_ds_buf_refs,
1687 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1688 tie(c_grid_buf));
1689
1690 if constexpr(access_id < num_access - 1)
1691 {
1692 constexpr auto cde_lds_and_global_step =
1693 sfc_cde_block.GetForwardStep(access_id);
1694
1695 // move on Ds
1696 static_for<0, NumDTensor, 1>{}([&](auto i) {
1697 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1698 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1699 });
1700
1701 // move on E
1702 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1703 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1704 I0,
1705 cde_lds_and_global_step);
1706 }
1707 });
1708 }
1709 }
1710};
1711
1712} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
constexpr auto BlockGemmABScalePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp:33
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:602
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:642
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:643
const AScaleType * p_a_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:646
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:649
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:641
const BScaleType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:647
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, const AScaleType *p_a_scale_grid_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:603
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:651
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:644
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:650
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:584
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:585
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:581
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:589
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:586
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:593
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:587
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:591
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:582
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:583
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:597
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:592
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:571
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:595
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:594
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:544
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:590
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:596
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:686
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:687
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:656
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:118
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:1205
remove_cvref_t< decltype(BlockGemmABScalePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:938
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp:1184
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340