21template <
typename ALayout,
29 typename CShuffleDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CDEElementwiseOperation,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
54 bool AThreadTransferSrcResetCoordinateAfterRun,
56 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 index_t BBlockTransferSrcVectorDim,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t BBlockTransferDstScalarPerVector_BK1,
62 bool BThreadTransferSrcResetCoordinateAfterRun,
64 index_t CShuffleMRepeatPerShuffle,
65 index_t CShuffleNRepeatPerShuffle,
66 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 typename CDEShuffleBlockTransferScalarPerVectors,
70 typename ComputeTypeA = EDataType,
71 typename ComputeTypeB = ComputeTypeA,
72 bool PermuteA =
false,
73 bool PermuteB =
false>
86 AElementwiseOperation,
87 BElementwiseOperation,
88 CDEElementwiseOperation,
100 ABlockTransferThreadClusterLengths_AK0_M_AK1,
101 ABlockTransferThreadClusterArrangeOrder,
102 ABlockTransferSrcAccessOrder,
103 ABlockTransferSrcVectorDim,
104 ABlockTransferSrcScalarPerVector,
105 ABlockTransferDstScalarPerVector_AK1,
106 AThreadTransferSrcResetCoordinateAfterRun,
108 BBlockTransferThreadClusterLengths_BK0_N_BK1,
109 BBlockTransferThreadClusterArrangeOrder,
110 BBlockTransferSrcAccessOrder,
111 BBlockTransferSrcVectorDim,
112 BBlockTransferSrcScalarPerVector,
113 BBlockTransferDstScalarPerVector_BK1,
114 BThreadTransferSrcResetCoordinateAfterRun,
116 CShuffleMRepeatPerShuffle,
117 CShuffleNRepeatPerShuffle,
118 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
119 CDEShuffleBlockTransferScalarPerVectors,
139 AElementwiseOperation,
140 BElementwiseOperation,
141 CDEElementwiseOperation,
153 ABlockTransferThreadClusterLengths_AK0_M_AK1,
154 ABlockTransferThreadClusterArrangeOrder,
155 ABlockTransferSrcAccessOrder,
156 ABlockTransferSrcVectorDim,
157 ABlockTransferSrcScalarPerVector,
158 ABlockTransferDstScalarPerVector_AK1,
159 AThreadTransferSrcResetCoordinateAfterRun,
161 BBlockTransferThreadClusterLengths_BK0_N_BK1,
162 BBlockTransferThreadClusterArrangeOrder,
163 BBlockTransferSrcAccessOrder,
164 BBlockTransferSrcVectorDim,
165 BBlockTransferSrcScalarPerVector,
166 BBlockTransferDstScalarPerVector_BK1,
167 BThreadTransferSrcResetCoordinateAfterRun,
169 CShuffleMRepeatPerShuffle,
170 CShuffleNRepeatPerShuffle,
171 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172 CDEShuffleBlockTransferScalarPerVectors,
230 std::array<index_t, NumATensor> StrideAs_,
231 std::array<index_t, NumBTensor> StrideBs_,
232 std::array<index_t, NumDTensor> StrideDs_,
258 std::cout <<
"problem {" <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
263 std::cout <<
"}, " <<
"SBs: {";
270 std::cout <<
"SDs: { ";
278 <<
", " <<
"KP:" <<
KPadded <<
", " <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0
279 <<
", " <<
"MBlock: " <<
MBlock <<
", " <<
"NBlock: " <<
NBlock <<
"}"
305 __host__
Argument(std::array<const void*, NumATensor> p_as_grid_,
306 std::array<const void*, NumBTensor> p_bs_grid_,
307 std::array<const void*, NumDTensor> p_ds_grid_,
308 EDataType* p_e_grid_,
312 std::array<index_t, NumATensor> StrideAs_,
313 std::array<index_t, NumBTensor> StrideBs_,
314 std::array<index_t, NumDTensor> StrideDs_,
317 const BScaleType* p_b_scale_grid_,
319 AElementwiseOperation a_element_op_,
320 BElementwiseOperation b_element_op_,
321 CDEElementwiseOperation cde_element_op_,
322 bool is_reduce_ =
false)
347 p_as_grid(i) =
static_cast<const ADataType_*
>(p_as_grid_[i]);
355 p_bs_grid(i) =
static_cast<const BDataType_*
>(p_bs_grid_[i]);
361 p_ds_grid(i) =
static_cast<const DDataType*
>(p_ds_grid_[i]);
415 if constexpr(!PermuteB)
422 const int k0_offset = karg.
KRead * karg.
N;
438 if(k_id < karg.
KBatch - 1)
470 template <index_t NumberOfBuffers,
typename BScaleGr
idDesc_BN_AK>
471 __device__
static auto MakeBScale(
const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
472 const BScaleType* p_b_scale_grid,
476 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
478 static constexpr auto wmma =
480 static constexpr auto KPerThread = wmma.
selected_wmma.k_per_wmma;
482 static constexpr auto ScaleSliceSizeN = NRepeat;
483 static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
488 constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
494 auto b_scale_thread_copy =
497 decltype(b_scale_grid_desc_bn_ak),
498 decltype(b_scale_thread_desc),
505 b_scale_grid_desc_bn_ak,
507 b_thread_offset_k / ScaleBlockK));
510 b_scale_thread_desc.GetElementSpaceSize());
513 typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
518 decltype(b_scale_grid_desc_bn_ak),
519 decltype(b_scale_thread_copy),
520 decltype(b_scale_grid_buf),
521 decltype(b_scale_thread_buf),
522 decltype(b_scale_thread_desc)>;
524 return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
529 return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
532 template <
bool HasMainKBlockLoop,
535 typename EpilogueArgument>
540 const BScaleType* p_b_scale_grid,
542 const Problem& problem,
543 AElementwiseOperation a_element_op,
544 BElementwiseOperation b_element_op,
545 CDEElementwiseOperation cde_element_op,
546 EpilogueArgument& epilogue_args)
549 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
551 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
553 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
555 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
556 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
558 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
559 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
561 e_grid_desc_m_n, problem.MBlock, problem.NBlock);
570 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
572 const auto block_work_idx =
575 if(!block_2_ctile_map.ValidCTileIndex(
577 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
578 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
583 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
584 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
587 auto b_scale_struct =
MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
591 Base::template
Run<
decltype(as_grid_desc_ak0_m_ak1),
592 decltype(bs_grid_desc_bk0_n_bk1),
593 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
594 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
595 decltype(b_scale_struct),
596 decltype(epilogue_args),
598 EGlobalMemoryDataOperation,
604 as_grid_desc_ak0_m_ak1,
605 bs_grid_desc_bk0_n_bk1,
606 ds_grid_desc_mblock_mperblock_nblock_nperblock,
607 e_grid_desc_mblock_mperblock_nblock_nperblock,
613 num_k_block_per_scale,
620 template <
bool HasMainKBlockLoop,
623 typename EpilogueArgument>
624 __device__
static void Run(
void* p_shared,
625 const SplitKBatchOffset& splitk_batch_offset,
627 EpilogueArgument& epilogue_args)
633 p_as_grid_splitk(i) =
static_cast<const ADataType_*
>(karg.p_as_grid[i]) +
634 splitk_batch_offset.a_k_split_offset[i];
641 p_bs_grid_splitk(i) =
static_cast<const BDataType_*
>(karg.p_bs_grid[i]) +
642 splitk_batch_offset.b_k_split_offset[i];
649 karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
650 karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
GemmSpecialization
Definition gemm_specialization.hpp:11
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:271
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:304
const BElementwiseOperation b_element_op
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:382
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:365
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t StrideScaleB_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:305
const AElementwiseOperation a_element_op
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:381
const CDEElementwiseOperation cde_element_op
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:383
const BScaleType * p_b_scale_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:380
AsGridPointer p_as_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:375
EDataType * p_e_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:378
DsGridPointer p_ds_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:377
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:370
bool is_reduce
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:384
BsGridPointer p_bs_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:376
index_t KBatch
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:291
index_t MPadded
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:292
index_t NBlock
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:299
index_t StrideE
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:289
__host__ void Print() const
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:256
index_t AK0
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:296
index_t MBlock
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:298
index_t BK0
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:297
index_t StrideScaleB
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:290
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:288
std::array< index_t, NumATensor > StrideAs
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:286
index_t KPadded
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:295
index_t N
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:284
std::array< index_t, NumBTensor > StrideBs
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:287
index_t K
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:285
index_t KRead
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:294
index_t M
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:283
index_t NPadded
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:293
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t StrideScaleB_, index_t KBatch_)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:227
std::array< index_t, NumATensor > a_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:457
index_t c_reduce_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:460
index_t scale_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:459
std::array< index_t, NumBTensor > b_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:458
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:390
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:127
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::I2 static constexpr auto I2
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::I1 static constexpr auto I1
static constexpr index_t NumATensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Block2CTileMap BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:467
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr index_t NumBTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static __device__ index_t GetKBlockPerScale()
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:527
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::AsDataType_ Tuple< ADataType > AsDataType_
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:222
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::I0 static constexpr auto I0
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::BlockwiseGemmPipe typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:463
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::BsDataType_ Tuple< BDataType > BsDataType_
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:223
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Base GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, true > Base
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:128
decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static constexpr index_t BPackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:536
static constexpr index_t NumDTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:624
static __device__ auto MakeBScale(const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const BScaleType *p_b_scale_grid, index_t block_n_id)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:471
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:214
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static constexpr auto I2
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
static constexpr auto I3
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:127
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:546
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static constexpr auto I1
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
static constexpr index_t NumATensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
static constexpr auto AK1Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:151
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr auto I6
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:130
static constexpr auto AK0Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:149
static constexpr index_t NumBTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static constexpr auto I0
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static constexpr auto I7
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:131
static constexpr auto I4
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static constexpr index_t BPackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
static constexpr auto BK1Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto BK0Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:150
decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static constexpr index_t NumDTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
static constexpr auto I5
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:129
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition utility/sequence.hpp:43
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition wmma_gemm.hpp:553
static constexpr auto selected_wmma
Definition wmma_gemm.hpp:636
Definition functional2.hpp:33
Definition device_base.hpp:197