24template <
typename ALayout,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
55 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
56 typename BBlockTransferThreadClusterArrangeOrder,
57 typename BBlockTransferSrcAccessOrder,
58 index_t BBlockTransferSrcVectorDim,
59 index_t BBlockTransferSrcScalarPerVector,
60 index_t BBlockTransferDstScalarPerVector_BK1,
62 index_t CShuffleMXdlPerWavePerShuffle,
63 index_t CShuffleNXdlPerWavePerShuffle,
64 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
65 typename CDEShuffleBlockTransferScalarPerVectors,
68 typename ComputeTypeA = CDataType,
69 typename ComputeTypeB = ComputeTypeA,
70 typename LDSTypeA = ComputeTypeA,
71 typename LDSTypeB = ComputeTypeB>
80 AElementwiseOperation,
81 BElementwiseOperation,
82 CElementwiseOperation>
90 template <index_t NXdlPerWave_>
102 AElementwiseOperation,
103 BElementwiseOperation,
104 CElementwiseOperation,
116 ABlockTransferThreadClusterLengths_AK0_M_AK1,
117 ABlockTransferThreadClusterArrangeOrder,
118 ABlockTransferSrcAccessOrder,
119 ABlockTransferSrcVectorDim,
120 ABlockTransferSrcScalarPerVector,
121 ABlockTransferDstScalarPerVector_AK1,
124 BBlockTransferThreadClusterLengths_BK0_N_BK1,
125 BBlockTransferThreadClusterArrangeOrder,
126 BBlockTransferSrcAccessOrder,
127 BBlockTransferSrcVectorDim,
128 BBlockTransferSrcScalarPerVector,
129 BBlockTransferDstScalarPerVector_BK1,
132 CShuffleMXdlPerWavePerShuffle,
133 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
134 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
135 CDEShuffleBlockTransferScalarPerVectors,
145 using Argument =
typename GridwiseGemm64::Argument;
149 template <
typename Gr
idwiseGemm>
150 float RunImp(
const typename GridwiseGemm::Argument& arg,
153 if(stream_config.log_level_ > 0)
158 if(!GridwiseGemm::CheckValidity(arg))
160 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
164 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
168 index_t k_grain = arg.KBatch * KPerBlock;
169 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
171 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
173 const auto Run = [&](
const auto& kernel) {
174 if(stream_config.flush_cache)
177 std::array<std::size_t, NumDTensor> DsSize;
181 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
182 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
183 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
184 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
187 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType);
189 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType);
191 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
192 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
196 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() *
sizeof(DDataType);
201 stream_config.rotating_count,
205 rotating_mem.Print();
207 auto run_flush_cache = [&]() {
214 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
216 arg_.M * arg_.N *
sizeof(CDataType),
217 stream_config.stream_id_));
232 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
234 arg.M * arg.N *
sizeof(CDataType),
235 stream_config.stream_id_));
238 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
242 constexpr index_t minimum_occupancy = []() {
249 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
257 if(has_main_k_block_loop)
287 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
297 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
309 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
311 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
323 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
325 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
338 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
340 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
353 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
355 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
368 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
370 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
382 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
384 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
399 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
409 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
421 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
423 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
435 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
437 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
450 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
452 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
465 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
467 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
480 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
482 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
494 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
496 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
515 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
538 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
564 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
587 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
645 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
685 if(CDEShuffleBlockTransferScalarPerVectors{}[
Number<0>{}] <= 1 && (arg.KBatch > 1))
694 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
708 std::array<const void*, NumDTensor> p_ds,
715 std::array<index_t, NumDTensor> StrideDs,
718 AElementwiseOperation a_element_op,
719 BElementwiseOperation b_element_op,
720 CElementwiseOperation c_element_op)
722 return Argument{
static_cast<const ADataType*
>(p_a),
723 static_cast<const BDataType*
>(p_b),
725 static_cast<CDataType*
>(p_c),
744 std::array<const void*, NumDTensor> p_ds,
751 std::array<ck::index_t, NumDTensor> StrideDs,
754 AElementwiseOperation a_element_op,
755 BElementwiseOperation b_element_op,
756 CElementwiseOperation c_element_op)
override
758 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
759 static_cast<const BDataType*
>(p_b),
761 static_cast<CDataType*
>(p_c),
778 return std::make_unique<Invoker>(
Invoker{});
784 auto str = std::stringstream();
786 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
790 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
798 str <<
"DeviceGemmXdlUniversal"
801 << std::string(ALayout::name)[0]
802 << std::string(BLayout::name)[0]
803 << std::string(CLayout::name)[0]
808 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
810 << MPerXDL<<
"x"<<NPerXDL <<
", "
812 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
814 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
815 <<
"BlkGemmPipelineScheduler: "
816 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
817 <<
"BlkGemmPipelineVersion: "
818 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
819 <<
"BlkGemmPipelinePrefetchStages: "
820 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:40
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:75
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1186
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:148
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:642
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:150
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:83
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:776
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:87
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:701
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:649
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:782
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:84
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:739
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:742
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:86
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:142
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:145
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:655
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:143
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:706
GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:91
Definition device_gemm_multiple_d.hpp:80
Definition flush_cache.hpp:174