blockwise_gemm_xdlops.hpp Source File
blockwise_gemm_xdlops.hpp
Go to the documentation of this file.
399 // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition ck.hpp:209
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition blockwise_gemm_smfmac_xdlops.hpp:44
static constexpr index_t KPerBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:58
__host__ static __device__ constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
Definition blockwise_gemm_xdlops.hpp:278
static constexpr index_t A_K1
Definition blockwise_gemm_smfmac_xdlops.hpp:63
static constexpr auto c_thread_desc_
Definition blockwise_gemm_smfmac_xdlops.hpp:427
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
Definition blockwise_gemm_xdlops.hpp:169
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_xdlops.hpp:83
static constexpr auto I2
Definition blockwise_gemm_smfmac_xdlops.hpp:47
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_xdlops.hpp:260
conditional_t< is_same_v< ComputeTypeB, ck::tf32_t >, float, ComputeTypeB > ElementDataTypeB
Definition blockwise_gemm_xdlops.hpp:54
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:108
static constexpr index_t WaveSize
Definition blockwise_gemm_smfmac_xdlops.hpp:54
__host__ static __device__ constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
Definition blockwise_gemm_xdlops.hpp:290
static constexpr index_t KPerThread
Definition blockwise_gemm_smfmac_xdlops.hpp:69
static constexpr index_t B_K1
Definition blockwise_gemm_smfmac_xdlops.hpp:64
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:97
static constexpr index_t MPerBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:56
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_smfmac_xdlops.hpp:76
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:226
static constexpr auto b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_smfmac_xdlops.hpp:295
static constexpr index_t NPerBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:57
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:150
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:200
ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:213
static constexpr auto I0
Definition blockwise_gemm_smfmac_xdlops.hpp:45
static constexpr auto a_thread_desc_
Definition blockwise_gemm_smfmac_xdlops.hpp:419
ThreadwiseTensorSliceTransfer_v4< FloatB, ComputeTypeB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_smfmac_xdlops.hpp:440
BThreadCopy b_thread_copy_
Definition blockwise_gemm_smfmac_xdlops.hpp:451
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:50
ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_xdlops.hpp:243
static constexpr auto a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_smfmac_xdlops.hpp:294
AThreadCopy a_thread_copy_
Definition blockwise_gemm_smfmac_xdlops.hpp:450
static constexpr index_t NWaves
Definition blockwise_gemm_smfmac_xdlops.hpp:53
ThreadwiseTensorSliceTransfer_v4< FloatA, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_smfmac_xdlops.hpp:430
static constexpr auto xdlops_gemm
Definition blockwise_gemm_smfmac_xdlops.hpp:66
conditional_t< is_same_v< ComputeTypeA, ck::tf32_t >, float, ComputeTypeA > ElementDataTypeA
Definition blockwise_gemm_xdlops.hpp:52
static constexpr index_t B_K0
Definition blockwise_gemm_smfmac_xdlops.hpp:62
static constexpr auto b_thread_desc_
Definition blockwise_gemm_smfmac_xdlops.hpp:423
static constexpr index_t A_K0
Definition blockwise_gemm_smfmac_xdlops.hpp:61
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_xdlops.hpp:306
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:121
static constexpr auto I3
Definition blockwise_gemm_smfmac_xdlops.hpp:48
static constexpr auto I1
Definition blockwise_gemm_smfmac_xdlops.hpp:46
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_xdlops.hpp:85
static constexpr index_t MWaves
Definition blockwise_gemm_smfmac_xdlops.hpp:52
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:187
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_xdlops.hpp:923
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_xdlops.hpp:722
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:876
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_xdlops.hpp:1009
static constexpr index_t A_K0
Definition blockwise_gemm_xdlops.hpp:702
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:835
static constexpr auto xdlops_gemm
Definition blockwise_gemm_xdlops.hpp:707
static constexpr index_t A_K1
Definition blockwise_gemm_xdlops.hpp:704
static constexpr auto b_thread_desc_
Definition blockwise_gemm_xdlops.hpp:1002
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_xdlops.hpp:724
static constexpr index_t NWaves
Definition blockwise_gemm_xdlops.hpp:699
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_xdlops.hpp:942
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_xdlops.hpp:945
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_xdlops.hpp:804
static constexpr index_t B_K0
Definition blockwise_gemm_xdlops.hpp:703
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:760
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition blockwise_gemm_xdlops.hpp:802
static constexpr auto a_thread_desc_
Definition blockwise_gemm_xdlops.hpp:998
static constexpr auto c_thread_desc_
Definition blockwise_gemm_xdlops.hpp:1006
static constexpr index_t WaveSize
Definition blockwise_gemm_xdlops.hpp:700
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:736
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_xdlops.hpp:862
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_xdlops.hpp:821
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:747
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:889
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_xdlops.hpp:941
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_xdlops.hpp:1019
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_xdlops.hpp:696
static constexpr index_t MWaves
Definition blockwise_gemm_xdlops.hpp:698
static constexpr index_t B_K1
Definition blockwise_gemm_xdlops.hpp:705
static constexpr index_t KPerThread
Definition blockwise_gemm_xdlops.hpp:710
AThreadCopy a_thread_copy_
Definition blockwise_gemm_xdlops.hpp:1029
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_xdlops.hpp:720
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:789
BThreadCopy b_thread_copy_
Definition blockwise_gemm_xdlops.hpp:1030
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:848
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_xdlops.hpp:906
Definition blockwise_gemm_xdlops.hpp:429
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< BlockSize, FloatA, FloatB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, ComputeTypeA, ComputeTypeB > Base
Definition blockwise_gemm_xdlops.hpp:430
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:1293
Definition xdlops_gemm.hpp:1821
Definition functional2.hpp:33
Definition dtype_vector.hpp:10