device_grouped_gemm_xdl_fixed_nk.hpp Source File
device_grouped_gemm_xdl_fixed_nk.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const index_t grid_size_grp, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:41
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:79
Definition functional2.hpp:33
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
BaseArgument()=default
BaseInvoker()=default
Definition device_grouped_gemm_xdl_fixed_nk.hpp:473
index_t grid_size_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:668
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:662
index_t barrier_size_grp_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:670
index_t group_count_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:656
CDEElementwiseOperation c_element_op_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:660
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:658
index_t grid_size_grp_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:669
index_t k_batch_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:673
ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK::Argument::grouped_gemm_kernel_args_dev
const void * grouped_gemm_kernel_args_dev
Definition device_grouped_gemm_xdl_fixed_nk.hpp:666
void UpdateKBatch(index_t k_batch)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:475
index_t sum_of_m
Definition device_grouped_gemm_xdl_fixed_nk.hpp:671
Argument(std::vector< const void * > &, std::vector< const void * > &, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:501
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:664
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:659
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:663
Definition device_grouped_gemm_xdl_fixed_nk.hpp:360
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:407
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:413
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops()=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
static constexpr auto I1
Definition device_grouped_gemm_xdl_fixed_nk.hpp:362
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const CGridDesc_M_N &c_grid_desc_m_n, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:384
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, index_t N, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:375
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:440
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:401
static constexpr auto I0
Definition device_grouped_gemm_xdl_fixed_nk.hpp:361
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:391
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
Definition device_grouped_gemm_xdl_fixed_nk.hpp:458
index_t M_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:465
index_t N_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:465
const void * a_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:460
void * e_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:463
index_t K_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:465
index_t StrideE_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:468
const void * b_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:461
std::array< index_t, NumDTensor > StrideDs_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:467
std::array< const void *, NumDTensor > ds_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:462
index_t StrideB_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:466
index_t StrideA_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:466
Definition device_grouped_gemm_xdl_fixed_nk.hpp:678
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:837
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_xdl_fixed_nk.hpp:682
DeviceOp::Argument Argument
Definition device_grouped_gemm_xdl_fixed_nk.hpp:679
Definition device_grouped_gemm_xdl_fixed_nk.hpp:314
UnderlyingBlockToCTileMap underlying_type
Definition device_grouped_gemm_xdl_fixed_nk.hpp:315
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:335
index_t block_start_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:354
__host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:342
__host__ __device__ OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off=0)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:317
index_t id_off_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:355
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:348
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:326
Block2ETileMap block_to_ctile_map_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:353
Definition device_grouped_gemm_xdl_fixed_nk.hpp:245
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:844
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:966
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition device_grouped_gemm_xdl_fixed_nk.hpp:1022
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:905
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:883
static constexpr auto I0
Definition device_grouped_gemm_xdl_fixed_nk.hpp:253
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_xdl_fixed_nk.hpp:309
OffsettedBlockToCTileMapMLoops< Block2ETileMap > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_xdl_fixed_nk.hpp:454
void SetKBatch(BaseArgument *p_arg, index_t k_batch) const override
Sets the k batch size.
Definition device_grouped_gemm_xdl_fixed_nk.hpp:1010
static constexpr index_t NumDTensor
Definition device_grouped_gemm_xdl_fixed_nk.hpp:251
ComputeType BComputeType
Definition device_grouped_gemm_xdl_fixed_nk.hpp:258
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops< MPerBlock, NPerBlock > Block2ETileMap
Definition device_grouped_gemm_xdl_fixed_nk.hpp:453
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_xdl_fixed_nk.hpp:249
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:919
DeviceGroupedGemm_Xdl_Fixed_NK DeviceOp
Definition device_grouped_gemm_xdl_fixed_nk.hpp:246
void SetDeviceKernelArgs(BaseArgument *p_arg, void *kernel_args) const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:954
static constexpr auto I2
Definition device_grouped_gemm_xdl_fixed_nk.hpp:255
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_xdl_fixed_nk.hpp:978
std::string GetTypeString() const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:925
static void SetKBatch(Argument &arg, index_t k_batch)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:1007
GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, BDataType, AComputeType, BComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer, ALDSType, BLDSType > GridwiseGemmBase
Definition device_grouped_gemm_xdl_fixed_nk.hpp:262
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &stream_config=StreamConfig{}) const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:990
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_xdl_fixed_nk.hpp:248
static constexpr auto I1
Definition device_grouped_gemm_xdl_fixed_nk.hpp:254
static auto MakeInvoker()
Definition device_grouped_gemm_xdl_fixed_nk.hpp:901
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_xdl_fixed_nk.hpp:310
ComputeType AComputeType
Definition device_grouped_gemm_xdl_fixed_nk.hpp:257
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:888
Definition device_grouped_gemm_fixed_nk.hpp:34
Definition device_grouped_gemm.hpp:80
Structure representing a single GEMM problem's arguments.
Definition device_grouped_gemm.hpp:29