gridwise_fpAintB_gemm_wmma.hpp Source File

gridwise_fpAintB_gemm_wmma.hpp Source File

Composable Kernel: gridwise_fpAintB_gemm_wmma.hpp Source File
gridwise_fpAintB_gemm_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/env.hpp"
19
20namespace ck {
21
// kernel_fpAintB_gemm_wmma: __global__ entry point of the "fp A x int B" WMMA GEMM.
//
// The kernel declares a static LDS buffer and forwards every argument unchanged to
// GridwiseGemm::Run<HasMainKBlockLoop>. All tiling, data movement and math live in Run.
//
// Parameters:
//   p_a_grid / a_grid_desc       - A operand in global memory and its grid descriptor.
//   p_b_grid / b_grid_desc       - B operand (BDataType) and its grid descriptor.
//                                  Run dequantizes B to ADataType before it enters LDS.
//   p_scale_grid / scale_grid_desc
//                                - scale values consumed together with B during
//                                  dequantization inside Run.
//   p_c_grid / c_grid_desc_mblock_mperblock_nblock_nperblock
//                                - C output and its 4-D [MBlock, MPerBlock, NBlock,
//                                  NPerBlock] descriptor.
//   a/b/c_element_op             - elementwise operators applied by Run's copies.
//   block_2_ctile_map            - maps the 1-D block id to a C tile. Run returns early
//                                  for blocks whose tile index is not valid.
//   HasMainKBlockLoop            - compile-time pipeline selector. Presumably computed on
//                                  the host via GridwiseGemm::CalculateHasMainKBlockLoop;
//                                  confirm at the launch site.
//
// NOTE(review): source line 38 (between `#if CK_USE_LAUNCH_BOUNDS` and `#endif`) is
// missing from this listing. It presumably holds the launch-bounds attribute that uses
// CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU; check the original header.
22template <typename GridwiseGemm,
23 typename ADataType,
24 typename BDataType,
25 typename ScaleDataType,
26 typename CDataType,
27 typename AGridDesc,
28 typename BGridDesc,
29 typename ScaleGridDesc,
30 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
31 typename AElementwiseOperation,
32 typename BElementwiseOperation,
33 typename CElementwiseOperation,
34 typename Block2CTileMap,
35 bool HasMainKBlockLoop>
36__global__ void
37#if CK_USE_LAUNCH_BOUNDS
39#endif
40 kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid,
41 const BDataType* __restrict__ p_b_grid,
42 const ScaleDataType* __restrict__ p_scale_grid,
43 CDataType* __restrict__ p_c_grid,
44 const AGridDesc a_grid_desc,
45 const BGridDesc b_grid_desc,
46 const ScaleGridDesc scale_grid_desc,
47 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
48 c_grid_desc_mblock_mperblock_nblock_nperblock,
49 const AElementwiseOperation a_element_op,
50 const BElementwiseOperation b_element_op,
51 const CElementwiseOperation c_element_op,
52 const Block2CTileMap block_2_ctile_map)
53{
// Only gfx11 / gfx12 device passes compile the GEMM body.
// Host and other-target passes take the #else branch below.
54#if(defined(__gfx11__) || defined(__gfx12__))
// Byte-sized static LDS allocation sized by GridwiseGemm::SharedMemTrait::lds_size.
// Run reuses it for the A/B tiles and later for the C shuffle.
55 __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
56
57 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
58 p_b_grid,
59 p_scale_grid,
60 p_c_grid,
61 p_shared,
62 a_grid_desc,
63 b_grid_desc,
64 scale_grid_desc,
65 c_grid_desc_mblock_mperblock_nblock_nperblock,
66 a_element_op,
67 b_element_op,
68 c_element_op,
69 block_2_ctile_map);
70#else
// Unsupported targets: the kernel does nothing. Assigning each parameter to `ignore`
// presumably exists to mark it as used and avoid unused-parameter warnings.
71 ignore = p_a_grid;
72 ignore = p_b_grid;
73 ignore = p_scale_grid;
74 ignore = p_c_grid;
75 ignore = a_grid_desc;
76 ignore = b_grid_desc;
77 ignore = scale_grid_desc;
78 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
79 ignore = a_element_op;
80 ignore = b_element_op;
81 ignore = c_element_op;
82 ignore = block_2_ctile_map;
83#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
84}
85
86// Assume B is Col-Major
87template <index_t BlockSize,
88 typename ADataType,
89 typename BDataType,
90 typename ScaleDataType,
91 typename AccDataType,
92 typename CShuffleDataType,
93 typename CDataType,
94 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
95 typename AGridDesc,
96 typename BGridDesc,
97 typename ScaleGridDesc,
98 typename CGridDesc_M_N,
99 typename AElementwiseOperation,
100 typename BElementwiseOperation,
101 typename CElementwiseOperation,
102 index_t MPerBlock,
103 index_t NPerBlock,
104 index_t KPerBlock,
105 index_t MPerWmma,
106 index_t NPerWmma,
107 index_t K1Value,
108 index_t MRepeat,
109 index_t NRepeat,
110 typename ABlockTransferThreadClusterLengths_K0_M_K1,
111 typename ABlockTransferThreadClusterArrangeOrder,
112 typename ABlockTransferSrcAccessOrder,
113 index_t ABlockTransferSrcVectorDim,
114 index_t ABlockTransferSrcScalarPerVector,
115 index_t ABlockTransferDstScalarPerVector_K1,
116 bool AThreadTransferSrcResetCoordinateAfterRun,
117 bool AEnableLds,
118 bool ABlockLdsExtraM,
119 typename BBlockTransferThreadClusterLengths_K0_N_K1,
120 typename BBlockTransferThreadClusterArrangeOrder,
121 typename BBlockTransferSrcAccessOrder,
122 index_t BBlockTransferSrcVectorDim,
123 index_t BBlockTransferSrcScalarPerVector,
124 index_t BBlockTransferDstScalarPerVector_K1,
125 bool BThreadTransferSrcResetCoordinateAfterRun,
126 bool BEnableLds,
127 bool BBlockLdsExtraN,
128 index_t CShuffleMRepeatPerShuffle,
129 index_t CShuffleNRepeatPerShuffle,
130 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
131 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
132 index_t NumGemmKPrefetchStage = 1,
136{
137 static constexpr auto I0 = Number<0>{};
138 static constexpr auto I1 = Number<1>{};
139 static constexpr auto I2 = Number<2>{};
140 static constexpr auto I3 = Number<3>{};
141 static constexpr auto I4 = Number<4>{};
142 static constexpr auto I5 = Number<5>{};
143 static constexpr auto I6 = Number<6>{};
144 static constexpr auto I7 = Number<7>{};
145
146 // FIX ME: To be deprecated
147 static constexpr auto K1 = Number<K1Value>{};
148
149 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
150 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
151 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
152
154
157 NumGemmKPrefetchStage,
158 LoopSched,
159 AEnableLds,
160 BEnableLds>())>;
161
162 // Describe how data store to (LDS/VGPR) buffer from Global memory
163 __host__ __device__ static constexpr auto MakeABlockDescriptor()
164 {
165 constexpr auto a_block_desc = [&]() {
166 if constexpr(AEnableLds)
167 {
168 // K0->M->K1 Per Block
169 constexpr auto K0PerBlock = KPerBlock / K1;
170 constexpr auto max_lds_align = K1;
171
172 if constexpr(ABlockLdsExtraM)
173 {
177 }
178 else
179 {
181 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
182 }
183 }
184 else
185 {
186 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
187 constexpr auto K0PerWmma = WmmaK / 2 / K1;
188 // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
192 I1,
194 I1,
195 I1,
196 K1),
200 K1,
201 K1,
202 K1,
203 I1));
204 }
205 }();
206
207 return a_block_desc;
208 }
209
210 __host__ __device__ static constexpr auto MakeBBlockDescriptor()
211 {
212 constexpr auto b_block_desc = [&]() {
213 if constexpr(BEnableLds)
214 {
215 // K0->N->K1 Per Block
216 constexpr auto K0PerBlock = KPerBlock / K1;
217 constexpr auto max_lds_align = K1;
218
219 if constexpr(BBlockLdsExtraN)
220 {
224 }
225 else
226 {
228 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
229 }
230 }
231 else
232 {
233 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
234 constexpr auto K0PerWmma = WmmaK / 2 / K1;
235 // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
239 I1,
241 I1,
242 I1,
243 K1),
247 K1,
248 K1,
249 K1,
250 I1));
251 }
252 }();
253
254 return b_block_desc;
255 }
256
257 __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
258 {
259 constexpr auto a_block_copy_step = [&]() {
260 if constexpr(AEnableLds)
261 {
262 constexpr auto K0PerBlock = KPerBlock / K1;
263
264 return make_multi_index(K0PerBlock, 0, 0);
265 }
266 else
267 {
268 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
269
270 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
271 }
272 }();
273
274 return a_block_copy_step;
275 }
276
277 __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
278 {
279 constexpr auto b_block_copy_step = [&]() {
280 if constexpr(BEnableLds)
281 {
282 constexpr auto K0PerBlock = KPerBlock / K1;
283
284 return make_multi_index(K0PerBlock, 0, 0);
285 }
286 else
287 {
288 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
289
290 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
291 }
292 }();
293
294 return b_block_copy_step;
295 }
296
297 // Describe how data read from (LDS/VGPR) buffer
298 template <typename ABlockDesc_>
299 __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
300 {
301
302 constexpr auto a_wave_desc = [&]() {
303 if constexpr(AEnableLds)
304 {
305 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
306 constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
307 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
308#ifdef __gfx12__
309 constexpr auto A_KRow = I2;
310#else
311 constexpr auto A_KRow = I1;
312#endif
314 ABlockDesc_{},
321 }
322 else
323 {
324 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
325 constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
326 constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
327 constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
328 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
329
330 // Err: merge transform cause non-constexpr issue
331
332 // return transform_tensor_descriptor(
333 // ABlockDesc_{},
334 // make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
335 // make_pass_through_transform(Number<MRepeat>{}),
336 // make_pass_through_transform(I1),
337 // make_pass_through_transform(I1),
338 // make_pass_through_transform(Number<A_K1>{})),
339 // make_tuple(Sequence<0, 3>{},
340 // Sequence<1>{},
341 // Sequence<2>{},
342 // Sequence<4>{},
343 // Sequence<5>{}),
344 // make_tuple(
345 // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
346 // Sequence<4>{}));
347
348 // Workaround, Freeze transform
351 I1,
353 I1,
354 Number<A_K1>{}));
355 }
356 }();
357
358 return a_wave_desc;
359 }
360
361 template <typename BBlockDesc_>
362 __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
363 {
364 constexpr auto b_wave_desc = [&]() {
365 if constexpr(BEnableLds)
366 {
367 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
368 constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
369 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
370#ifdef __gfx12__
371 constexpr auto B_KRow = I2;
372#else
373 constexpr auto B_KRow = I1;
374#endif
376 BBlockDesc_{},
383 }
384 else
385 {
386 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
387 constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
388 constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
389 constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
390 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
391
392 // Workaround, Freeze transform
395 I1,
397 I1,
398 Number<B_K1>{}));
399 }
400 }();
401
402 return b_wave_desc;
403 }
404
405 __host__ __device__ static constexpr auto
406 // *Caution Here repeat is shuffle repeat
408 {
409 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
413 I1,
415
416 return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
417 }
418
419 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
420 template <typename Block2CTileMap>
421 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
422 const BGridDesc& b_grid_desc,
423 const CGridDesc_M_N& c_grid_desc_m_n,
424 const Block2CTileMap& block_2_ctile_map)
425 {
426 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
427 "wrong! K1 need to be known at compile-time");
428
429 static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
430 (NPerBlock % (NRepeat * NPerWmma)) == 0,
431 "Invalid tuning param!");
432
433 const auto GetAProblemsizeMK = [&]() {
434 if constexpr(AEnableLds)
435 {
436 return make_tuple(a_grid_desc.GetLength(I1),
437 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
438 }
439 else
440 {
441 return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
442 a_grid_desc.GetLength(I5),
443 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
444 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
445 }
446 };
447
448 const auto GetBProblemsizeNK = [&]() {
449 if constexpr(BEnableLds)
450 {
451 return make_tuple(b_grid_desc.GetLength(I1),
452 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
453 }
454 else
455 {
456 return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
457 b_grid_desc.GetLength(I5),
458 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
459 b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
460 }
461 };
462
463 const auto M = GetAProblemsizeMK()[I0];
464 const auto N = GetBProblemsizeNK()[I0];
465 const auto K = GetAProblemsizeMK()[I1];
466
467 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
468 K == GetBProblemsizeNK()[I1]))
469 {
470 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
471 {
472 printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
473 GetAProblemsizeMK()[I0],
474 GetAProblemsizeMK()[I1],
475 GetBProblemsizeNK()[I0],
476 GetBProblemsizeNK()[I1],
477 c_grid_desc_m_n.GetLength(I0),
478 c_grid_desc_m_n.GetLength(I1));
479 printf("GridwiseOp err: ProblemSize check");
480 }
481 return false;
482 }
483
484 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
485 {
486 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
487 {
488 printf("GridwiseOp err: ProblemSize division");
489 }
490 return false;
491 }
492
493 // check gridwise gemm pipeline
494 const auto num_k_loop = K / KPerBlock;
495
496 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
497 {
498 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
499 {
500 printf("GridwiseOp err: Pipeline not support this k_loop");
501 }
502 return false;
503 }
504
505 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
506 {
507 return false;
508 }
509
510 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
511 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
512
513 if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
514 b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
515 {
516 return false;
517 }
518 return true;
519 }
520
521 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
522 {
523 const index_t num_loop = K / KPerBlock;
524
525 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
526 }
527
528 __host__ __device__ static constexpr auto
529 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
530 {
531 const auto M = c_grid_desc_m_n.GetLength(I0);
532 const auto N = c_grid_desc_m_n.GetLength(I1);
533
534 const auto MBlock = M / MPerBlock;
535 const auto NBlock = N / NPerBlock;
536
537 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
538 c_grid_desc_m_n,
543
544 return c_grid_desc_mblock_mperblock_nblock_nperblock;
545 }
546
547 // return block_id to C matrix tile idx (m0, n0) mapping
548 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
549 const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
550 {
552 c_grid_desc_m_n);
553 }
554
557 CGridDesc_M_N{}))>;
559 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
560
562 {
563 // LDS allocation for A and Dequantized B: be careful of DataType
564 // scale would not put into LDS.
565 using LDS_ADataType = ADataType;
566 using LDS_BDataType = ADataType;
567 using LDS_CDataType = CShuffleDataType;
568 static constexpr auto max_lds_align = K1;
569
570 static constexpr auto a_block_space_size_aligned =
571 AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
573 : 0;
574 static constexpr auto b_block_space_size_aligned =
575 BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
577 : 0;
578
579 static constexpr auto a_block_space_offset = 0;
580 // B would be dequantize to ADataType before enter LDS
581 // b_lds_offset = LDS size allocated for a in byte / LDS_BDataType
582 static constexpr auto b_block_space_offset =
584 sizeof(LDS_BDataType);
585
586 // LDS allocation for C shuffle in LDS
587 static constexpr auto c_shuffle_block_space_size =
589 .GetElementSpaceSize();
590
591 static constexpr auto c_shuffle_block_space_offset = 0;
592
593 static constexpr auto lds_size =
597 };
598
599 template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
600 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
601 const BDataType* __restrict__ p_b_grid,
602 const ScaleDataType* __restrict__ p_scale_grid,
603 CDataType* __restrict__ p_c_grid,
604 void* __restrict__ p_shared,
605 const AGridDesc& a_grid_desc,
606 const BGridDesc& b_grid_desc,
607 const ScaleGridDesc& scale_grid_desc,
609 c_grid_desc_mblock_mperblock_nblock_nperblock,
610 const AElementwiseOperation& a_element_op,
611 const BElementwiseOperation& b_element_op,
612 const CElementwiseOperation& c_element_op,
613 const Block2CTileMap& block_2_ctile_map)
614 {
615 // clang-format off
616/*******************************************************************************/
617// Memory buffer zone.
618 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
619 p_a_grid, a_grid_desc.GetElementSpaceSize());
620 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
621 p_b_grid, b_grid_desc.GetElementSpaceSize());
622 const auto scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
623 p_scale_grid, scale_grid_desc.GetElementSpaceSize());
625 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
626
627/*******************************************************************************/
628// BlockIdx.x -> [BlockId.m, BlockId.n]
629 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
630 if(!block_2_ctile_map.ValidCTileIndex(
631 block_work_idx,
632 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
633 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
634 { return; }
635
636 // Store BlockId into SGPR
637 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
638 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
639
640/*******************************************************************************/
641// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy
642 const auto K = [&](){
643 if constexpr(AEnableLds){
644 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
645 }
646 else{
647 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
648 * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
649 }
650 }();
651
652 constexpr auto a_block_desc = MakeABlockDescriptor();
653 constexpr auto b_block_desc = MakeBBlockDescriptor();
654
655 auto a_block_trait = [&](){
656 // A matrix blockwise copy
657 if constexpr(AEnableLds)
658 {
659 constexpr auto K0PerBlock = KPerBlock/ K1;
661 static_cast<ADataType*>(p_shared),
663
664 auto a_blockwise_copy =
666/* typename SrcElementwiseOperation, */ AElementwiseOperation,
667/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
668/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
669/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
670/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
671/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
672/* typename SrcData, */ ADataType,
673/* typename DstData, */ ADataType,
674/* typename SrcDesc, */ decltype(a_grid_desc),
675/* typename DstDesc, */ decltype(a_block_desc),
676/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
677/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
678/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
679/* index_t DstVectorDim, */ 2,
680/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
681/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
682/* index_t SrcScalarStrideInVector, */ 1,
683/* index_t DstScalarStrideInVector, */ 1,
684/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
685/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
686 NumGemmKPrefetchStage>(
687 a_grid_desc,
688 make_multi_index(0, m_block_data_idx_on_grid, 0),
689 a_element_op,
690 a_block_desc,
691 make_multi_index(0, 0, 0),
693
694 return make_tuple(a_block_buf, a_blockwise_copy);
695 }
696 else
697 {
698 // Thread-wise copy
699 // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
700 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
701 constexpr auto K0PerWmma = WmmaK/2/K1Value;
703 a_block_desc.GetElementSpaceSize());
704
705 // Limitation: NumDim of Src and Dst descriptor should be identical
706 auto a_blockwise_copy =
708 ADataType,
709 decltype(a_grid_desc),
710 decltype(a_block_desc),
713 I1,
715 I1,
716 I1,
719 6,
720 ABlockTransferSrcScalarPerVector,
721 AThreadTransferSrcResetCoordinateAfterRun,
722 true>(
723 a_grid_desc,
725 m_block_data_idx_on_grid/(MWaves * MPerWmma),
727 0,
728 (get_thread_local_1d_id() % 32 )/ 16,
730 0));
731
732 return make_tuple(a_block_buf, a_blockwise_copy);
733 }
734 };
735
736 auto b_block_trait = [&](){
737 if constexpr(BEnableLds)
738 {
739 constexpr auto K0PerBlock = KPerBlock/ K1;
741 static_cast<ADataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
743
744 auto b_blockwise_copy =
746/* typename SrcElementwiseOperation, */ BElementwiseOperation,
747/* typename ScaleElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
748/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
749/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
750/* typename BlockSliceLengths, */ Sequence<K0PerBlock, NPerBlock, K1>,
751/* typename BlockScaleSliceLengths, */ Sequence<K0PerBlock, NPerBlock, I1>,
752/* typename ThreadClusterLengths, */ BBlockTransferThreadClusterLengths_K0_N_K1,
753/* typename ThreadClusterArrangeOrder, */ BBlockTransferThreadClusterArrangeOrder,
754/* typename SrcData, */ BDataType,
755/* typename ScaleData, */ ScaleDataType,
756/* typename DstData, */ ADataType,
757/* typename SrcDesc, */ decltype(b_grid_desc),
758/* typename ScaleDesc, */ decltype(scale_grid_desc),
759/* typename DstDesc, */ decltype(b_block_desc),
760/* typename SrcDimAccessOrder, */ BBlockTransferSrcAccessOrder,
761/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
762/* index_t SrcVectorDim, */ BBlockTransferSrcVectorDim,
763/* index_t DstVectorDim, */ 2,
764/* index_t SrcScalarPerVector, */ BBlockTransferSrcScalarPerVector,
765/* index_t ScaleScalarPerVector, */ 1,
766/* index_t DstScalarPerVector, */ BBlockTransferDstScalarPerVector_K1,
767/* index_t SrcScalarStrideInVector, */ 1,
768/* index_t ScaleScalarStrideInVector, */ 1,
769/* index_t DstScalarStrideInVector, */ 1,
770/* bool ThreadTransferSrcResetCoordinateAfterRun, */ BThreadTransferSrcResetCoordinateAfterRun,
771/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
772 NumGemmKPrefetchStage>(
773 b_grid_desc,
774 make_multi_index(0, n_block_data_idx_on_grid, 0),
775 b_element_op,
776 scale_grid_desc,
777 make_multi_index(0, n_block_data_idx_on_grid, 0),
779 b_block_desc,
780 make_multi_index(0, 0, 0),
782
783 return make_tuple(b_block_buf, b_blockwise_copy);
784 }
785 else
786 {
787 // Thread-wise copy
788 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
789 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
790 constexpr auto K0PerWmma = WmmaK/2/K1Value;
792 b_block_desc.GetElementSpaceSize());
793
794 // Limitation: NumDim of Src and Dst descriptor should be identical
795 auto b_blockwise_copy =
797 BDataType,
798 decltype(b_grid_desc),
799 decltype(b_block_desc),
802 I1,
804 I1,
805 I1,
808 6,
809 BBlockTransferSrcScalarPerVector,
810 BThreadTransferSrcResetCoordinateAfterRun,
811 true>(
812 b_grid_desc,
814 n_block_data_idx_on_grid/(NWaves * NPerWmma),
816 0,
817 (get_thread_local_1d_id() % 32 )/ 16,
819 0));
820
821 return make_tuple(b_block_buf, b_blockwise_copy);
822 }
823 };
824
825 auto a_block_buf = a_block_trait()[I0];
826 auto a_blockwise_copy = a_block_trait()[I1];
827
828 auto b_block_buf = b_block_trait()[I0];
829 auto b_blockwise_copy = b_block_trait()[I1];
830/*******************************************************************************/
831 // GEMM
832 constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
833
834 auto blockwise_gemm =
835 BlockwiseGemmWMMA<BlockSize,
836 ADataType,
837 ADataType, //Dequantized
838 AccDataType,
839 decltype(MakeAWaveDescriptor(a_block_desc)),
840 decltype(MakeBWaveDescriptor(b_block_desc)),
841 MPerBlock,
842 NPerBlock,
843 KPerBlock,
844 MPerWmma,
845 NPerWmma,
846 MRepeat,
847 NRepeat,
848 KPack,
849 AEnableLds,
850 BEnableLds>{};
851
852 // Prepare Register for C matrix
853 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
854
855/*******************************************************************************/
856 // Shift Per SUB_K
857 constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
858 constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
859
860 // gridwise GEMM pipeline
861 const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
862 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
863 a_block_desc,
864 a_blockwise_copy,
865 a_grid_buf,
866 a_block_buf,
867 a_block_slice_copy_step,
868 b_grid_desc,
869 b_block_desc,
870 b_blockwise_copy,
871 b_grid_buf,
872 b_block_buf,
873 b_block_slice_copy_step,
874 scale_grid_desc,
875 scale_grid_buf,
876 blockwise_gemm,
877 c_thread_buf,
878 KBlockMainLoop);
879/*******************************************************************************/
880 // write out to C, implement shuffle
881 {
882 // C mapping in single thread.
883 constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
884 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
885
886 // C mapping in single block
887 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
888 blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
889
890 constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
891 constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
892 constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
893 constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
894 constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
895
896 // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
897 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
899
900 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
901 static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
903
904 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
905 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
909 Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
910 MWave, // MWave
911 MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
912 MAccVgprs)),
915 Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
916 NWave, // NWave
917 NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
920
921 // calculate origin of thread output tensor on global memory
922 // blockwise GEMM c matrix starting index
923 const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
924
925 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
926 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
927
928 const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
930 make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
933
934 const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
936 make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
939
940 const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
941 make_multi_index(m_thread_data_on_block));
942
943 const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
944 make_multi_index(n_thread_data_on_block));
945
946 // shuffle: threadwise copy C from VGPR to LDS
947 auto c_thread_copy_vgpr_to_lds =
949 CShuffleDataType,
950 decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
951 decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
953 Sequence<CShuffleMRepeatPerShuffle,
954 I1,
955 I1,
956 CShuffleNRepeatPerShuffle,
957 I1,
958 I1,
959 MAccVgprs>,
961 6,
962 1, // vector write pixel
964 1,
965 true>{
966 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
968 m_thread_data_on_block_idx[I1],
969 m_thread_data_on_block_idx[I2],
970 0,
971 n_thread_data_on_block_idx[I1],
972 n_thread_data_on_block_idx[I2],
973 m_thread_data_on_block_idx[I3]),
975
976 // shuffle: blockwise copy C from LDS to global
977 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
978 ThisThreadBlock, // ThreadGroup
979 CElementwiseOperation, // ElementwiseOperation,
980 CGlobalMemoryDataOperation, // DstInMemOp,
981 Sequence<1,
982 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
983 1,
984 CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
985 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
986 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
987 CShuffleDataType, // typename SrcData,
988 CDataType, // typename DstData,
989 decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
990 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
991 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
992 3, // index_t VectorDim,
993 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
994 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
995 false> // bool ThreadTransferDstResetCoordinateAfterRun>
996 {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
997 make_multi_index(0, 0, 0, 0),
998 c_grid_desc_mblock_mperblock_nblock_nperblock,
999 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1000 c_element_op};
1001
1002 // space filling curve for local reg & global memory
1003 // space filling curve for threadwise C in VGPR
1004 constexpr auto sfc_c_vgpr =
1007 Sequence<CShuffleMRepeatPerShuffle,
1008 1,
1009 1,
1010 CShuffleNRepeatPerShuffle,
1011 1,
1012 1,
1013 MAccVgprs>>{};
1014
1015 // space filling curve for shuffled blockwise C in global mem
1016 constexpr auto sfc_c_global =
1019 Sequence<1,
1020 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1021 1,
1022 CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
1023
1024 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1025
1026 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1027
1028 static_for<0, num_access, 1>{}([&](auto access_id) {
1029 // make sure it's safe to write to LDS
1031
1032 // each thread write its data from VGPR to LDS
1033 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1034 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1035 c_thread_buf,
1036 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1037 c_shuffle_block_buf);
1038
1039 // make sure it's safe to read from LDS
1041
1042 // each block copy its data from LDS to global
1043 c_shuffle_block_copy_lds_to_global.Run(
1044 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1045 c_shuffle_block_buf,
1046 c_grid_desc_mblock_mperblock_nblock_nperblock,
1047 c_grid_buf);
1048
1049 if constexpr(access_id < num_access - 1)
1050 {
1051 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1052
1053 // move on C
1054 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1055 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1056 }
1057 });
1058 }
1059 // clang-format on
1060 }
1061};
1062
1063} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ weight_only
Definition gridwise_gemm_pipeline_selector.hpp:23
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__global__ void kernel_fpAintB_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, const ScaleDataType *__restrict__ p_scale_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const ScaleGridDesc scale_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_fpAintB_gemm_wmma.hpp:40
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_wmma.hpp:550
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_wmma.hpp:585
Definition gridwise_fpAintB_gemm_wmma.hpp:562
static constexpr auto c_shuffle_block_space_size
Definition gridwise_fpAintB_gemm_wmma.hpp:587
static constexpr auto max_lds_align
Definition gridwise_fpAintB_gemm_wmma.hpp:568
static constexpr auto a_block_space_size_aligned
Definition gridwise_fpAintB_gemm_wmma.hpp:570
static constexpr auto lds_size
Definition gridwise_fpAintB_gemm_wmma.hpp:593
ADataType LDS_BDataType
Definition gridwise_fpAintB_gemm_wmma.hpp:566
static constexpr auto c_shuffle_block_space_offset
Definition gridwise_fpAintB_gemm_wmma.hpp:591
ADataType LDS_ADataType
Definition gridwise_fpAintB_gemm_wmma.hpp:565
static constexpr auto b_block_space_offset
Definition gridwise_fpAintB_gemm_wmma.hpp:582
static constexpr auto b_block_space_size_aligned
Definition gridwise_fpAintB_gemm_wmma.hpp:574
CShuffleDataType LDS_CDataType
Definition gridwise_fpAintB_gemm_wmma.hpp:567
static constexpr auto a_block_space_offset
Definition gridwise_fpAintB_gemm_wmma.hpp:579
Definition gridwise_fpAintB_gemm_wmma.hpp:136
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_fpAintB_gemm_wmma.hpp:529
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_fpAintB_gemm_wmma.hpp:421
__host__ static __device__ constexpr auto MakeAWaveDescriptor(const ABlockDesc_ &)
Definition gridwise_fpAintB_gemm_wmma.hpp:299
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_fpAintB_gemm_wmma.hpp:521
__host__ static __device__ constexpr auto MakeBWaveDescriptor(const BBlockDesc_ &)
Definition gridwise_fpAintB_gemm_wmma.hpp:362
__host__ static __device__ constexpr auto MakeABlockDescriptor()
Definition gridwise_fpAintB_gemm_wmma.hpp:163
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, const ScaleDataType *__restrict__ p_scale_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const ScaleGridDesc &scale_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_fpAintB_gemm_wmma.hpp:600
__host__ static __device__ constexpr auto MakeBBlockSliceCopyStep()
Definition gridwise_fpAintB_gemm_wmma.hpp:277
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_fpAintB_gemm_wmma.hpp:548
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition gridwise_fpAintB_gemm_wmma.hpp:407
__host__ static __device__ constexpr auto MakeBBlockDescriptor()
Definition gridwise_fpAintB_gemm_wmma.hpp:210
__host__ static __device__ constexpr auto MakeABlockSliceCopyStep()
Definition gridwise_fpAintB_gemm_wmma.hpp:257
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer with dequantization.
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:56
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition is_known_at_compile_time.hpp:14
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129