23template <
typename ADataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
44 typename ABlockTransferThreadClusterLengths_K0_M_K1,
45 typename ABlockTransferSrcAccessOrder,
48 bool ABlockLdsAddExtraM,
49 typename BBlockTransferThreadClusterLengths_K0_N_K1,
50 typename BBlockTransferSrcAccessOrder,
53 bool BBlockLdsAddExtraN,
54 index_t CShuffleMRepeatPerShuffle,
55 index_t CShuffleNRepeatPerShuffle,
56 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
57 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
58 typename ComputeType = CDataType,
68 AElementwiseOperation,
69 BElementwiseOperation,
70 CElementwiseOperation,
82 template <index_t NXdlPerWave_>
92 AElementwiseOperation,
93 BElementwiseOperation,
94 CElementwiseOperation,
96 NumGemmKPrefetchStage,
105 ABlockTransferThreadClusterLengths_K0_M_K1,
106 ABlockTransferSrcAccessOrder,
107 ABlockTransferSrcVectorDim,
108 ABlockTransferScalarPerVector,
110 BBlockTransferThreadClusterLengths_K0_N_K1,
111 BBlockTransferSrcAccessOrder,
112 BBlockTransferSrcVectorDim,
113 BBlockTransferScalarPerVector,
115 CShuffleMRepeatPerShuffle,
116 CShuffleNRepeatPerShuffle,
117 CBlockTransferScalarPerVector_NWaveNPerXDL,
118 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
128 const BDataType* p_b_grid_,
129 CDataType* p_c_grid_,
141 AElementwiseOperation a_element_op_,
142 BElementwiseOperation b_element_op_,
143 CElementwiseOperation c_element_op_)
174 template <
typename Argument_>
180 template <
typename Gr
idwiseGemm>
183 if(stream_config.log_level_ > 0)
188 const auto kbatch = karg.k_batch;
189 auto arg = *
reinterpret_cast<const typename GridwiseGemm::Argument*
>(&karg);
190 if(!GridwiseGemm::CheckValidity(arg))
192 throw std::runtime_error(
193 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
199 ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
200 const auto K0Padded = karg.K0Padded;
202 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
206 const auto Run = [&](
const auto& kernel) {
208 hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
210 karg.M * karg.N *
sizeof(CDataType),
211 stream_config.stream_id_));
225 if(has_main_k0_block_loop)
234 AElementwiseOperation,
235 BElementwiseOperation,
236 CElementwiseOperation>;
247 AElementwiseOperation,
248 BElementwiseOperation,
249 CElementwiseOperation>;
263 AElementwiseOperation,
264 BElementwiseOperation,
265 CElementwiseOperation>;
276 AElementwiseOperation,
277 BElementwiseOperation,
278 CElementwiseOperation>;
293 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
326 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(karg));
339 const BDataType* p_b,
347 AElementwiseOperation a_element_op,
348 BElementwiseOperation b_element_op,
349 CElementwiseOperation c_element_op,
383 AElementwiseOperation a_element_op,
384 BElementwiseOperation b_element_op,
385 CElementwiseOperation c_element_op,
388 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
389 static_cast<const BDataType*
>(p_b),
390 static_cast<CDataType*
>(p_c),
410 return std::make_unique<Invoker>(
Invoker{});
416 auto str = std::stringstream();
418 std::map<LoopScheduler, std::string> LoopSchedToString{
421 std::map<PipelineVersion, std::string> PipelineVersionToString{
425 str <<
"DeviceGemmXdlSplitKCShuffle_LdsDirectLoad"
430 << K0PerBlock <<
", "
434 << MXdlPerWave <<
", "
435 << NXdlPerWave <<
", "
436 << ABlockTransferScalarPerVector <<
", "
437 << BBlockTransferScalarPerVector <<
", "
438 << CShuffleMRepeatPerShuffle <<
", "
439 << CShuffleNRepeatPerShuffle <<
", "
442 <<
" LoopScheduler: "
443 << LoopSchedToString[LoopSched] <<
", "
444 <<
"PipelineVersion: "
445 << PipelineVersionToString[PipelineVer] <<
", "
447 << NumGemmKPrefetchStage;
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__global__ void kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, const Block2CTileMap &b2c_map, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:35
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:99
ck::GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType >::CalculateMPadded __host__ static __device__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:190
ck::GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:207
ck::GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:404
ck::GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap())> DefaultBlock2CTileMap
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:547
ck::GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType >::CalculateK0Padded __host__ static __device__ auto CalculateK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:200
ck::GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType >::CalculateNPadded __host__ static __device__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:195
Definition device_base.hpp:197
Definition device_gemm_splitk.hpp:26
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:126
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t MPadded_, index_t NPadded_, index_t KPadded_, index_t K0Padded_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:127
BElementwiseOperation b_element_op
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:165
AElementwiseOperation a_element_op
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:164
CElementwiseOperation c_element_op
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:166
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:173
void Print(const Argument_ &karg)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:175
float RunImp(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:181
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:290
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:72
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:408
static constexpr auto I1
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:78
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:74
typename GridwiseGemm64::DefaultBlock2CTileMap DefaultBlock2CTileMap
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:169
GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType > GridwiseGemmBase
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:83
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:123
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch=1) override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:374
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:122
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:303
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:338
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:333
std::string GetTypeString() const override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:414
static constexpr auto I3
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:80
static auto MakeInvoker()
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:371
static constexpr auto I0
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:77
static constexpr auto I2
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:79
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:297
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:75