batched_gemm_kernel.hpp Source File

batched_gemm_kernel.hpp Source File

Composable Kernel: batched_gemm_kernel.hpp Source File
batched_gemm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck_tile {
11
20{
// Host-side constructor for a batched GEMM problem.
// The single A and B pointers and their strides are each wrapped in a
// brace-initialized list, the D-tensor pointer and stride lists are passed
// empty, and everything is forwarded to the UniversalGemmHostArgs<> base.
// The per-batch strides and the batch count are stored in this struct's own
// members.
// NOTE(review): M_, N_ and K_ (used in the initializer list below) are declared
// on listing lines 25-27, which are missing from this extract; the full
// signature is visible in the cross-reference section of this page.
21 CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_,
22 const void* b_ptr_,
23 void* c_ptr_,
24 ck_tile::index_t k_batch_,
28 ck_tile::index_t stride_A_,
29 ck_tile::index_t stride_B_,
30 ck_tile::index_t stride_C_,
31 ck_tile::index_t batch_stride_A_,
32 ck_tile::index_t batch_stride_B_,
33 ck_tile::index_t batch_stride_C_,
34 ck_tile::index_t batch_count_)
35 : UniversalGemmHostArgs<>({a_ptr_},
36 {b_ptr_},
37 {/*ds_ptr*/},
38 c_ptr_,
39 k_batch_,
40 M_,
41 N_,
42 K_,
43 {stride_A_},
44 {stride_B_},
45 {/*stride_Ds_*/},
46 stride_C_),
47 batch_stride_A(batch_stride_A_),
48 batch_stride_B(batch_stride_B_),
// The "C" batch stride argument is stored under the member name batch_stride_E.
49 batch_stride_E(batch_stride_C_),
50 batch_count(batch_count_)
51 {
52 }
53
58};
59
60template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
62{
68
72
77
82
// Compile-time checks on the A, B and C/E layout and data-type parameters;
// per the messages below, each must be a single (scalar) type rather than
// multiple parameters.
// NOTE(review): the asserted conditions (listing lines 85, 90 and 93-95) are
// missing from this extract, so only the diagnostic messages are visible here.
84 static_assert(
86 "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
87
89 static_assert(
91 "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
92
96 "C/CLayout and C/EDataType must be scalars.");
97
105
107
108 [[nodiscard]] CK_TILE_HOST static auto GetName() -> const std::string
109 {
110 // clang-format off
111 using P_ = GemmPipeline;
112 return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
113 concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
114 concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
115 concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
116 // clang-format on
117 }
118
119 CK_TILE_HOST static constexpr auto
120 GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
121 {
122 return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
123 }
124
// Returns the thread-block dimensions to launch this kernel with.
// NOTE(review): the branch condition (listing line 127) and the else-branch
// return statement (listing line 133) are missing from this extract. The
// visible branch launches half of UniversalGemmKernel::kBlockSize threads;
// presumably the condition is a wave32 check (is_wave32() appears in this
// page's cross-references) and the else branch returns the full kBlockSize.
// Confirm against the real header.
125 CK_TILE_HOST static auto BlockSize() -> dim3
126 {
128 {
129 return dim3(UniversalGemmKernel::kBlockSize / 2);
130 }
131 else
132 {
134 }
135 }
136
// Converts host-side arguments into the BatchedGemmKernelArgs aggregate that is
// passed to the kernel.
// The inner brace group is filled from the universal-GEMM fields of hostArgs
// (pointers, M/N/K, strides, k_batch); the four trailing initializers are the
// batched-GEMM additions (batch strides for A, B, E and the batch count).
// Note that k_batch is the LAST entry of the inner group here, whereas the host
// constructor takes it right after the pointers.
// NOTE(review): the line carrying the function name and its parameter (listing
// line 138) is missing from this extract; per the cross-references it is
// MakeKernelArgs(const BatchedGemmHostArgs& hostArgs).
137 CK_TILE_HOST static constexpr BatchedGemmKernelArgs
139 {
140 return BatchedGemmKernelArgs{{hostArgs.as_ptr,
141 hostArgs.bs_ptr,
142 hostArgs.ds_ptr,
143 hostArgs.e_ptr,
144 hostArgs.M,
145 hostArgs.N,
146 hostArgs.K,
147 hostArgs.stride_As,
148 hostArgs.stride_Bs,
149 hostArgs.stride_Ds,
150 hostArgs.stride_E,
151 hostArgs.k_batch},
152 hostArgs.batch_stride_A,
153 hostArgs.batch_stride_B,
154 hostArgs.batch_stride_E,
155 hostArgs.batch_count};
156 }
157
// Body of GetSmemSize(); the signature (listing line 158) is missing from this
// extract.
// Returns the larger of the GEMM pipeline's and the epilogue pipeline's
// shared-memory requirements. operator() allocates one __shared__ buffer of
// this size and hands it to the GEMM run, so presumably both stages reuse the
// same buffer rather than needing the sum of the two sizes. Confirm.
159 {
160 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
161 }
162
// Host-side validation of the batched-GEMM specific kernel arguments.
// Each failed check returns false; a diagnostic is emitted only when the
// CK_TILE_LOGGING environment switch is enabled.
// NOTE(review): several listing lines are missing from this extract: the
// signature (164), the CK_TILE_ERROR( openers (178, 187, 196) and the final
// return (201). Presumably the final return delegates to
// UniversalGemmKernel::IsSupportedArgument (listed in the cross-references).
// Confirm against the real header.
// NOTE(review): the products M * K, K * N and M * N are formed in index_t,
// which the cross-references show as int32_t. They can overflow for large
// problem sizes, which would make these comparisons unreliable; consider
// widening before multiplying.
163 CK_TILE_HOST static auto
165 {
// At least one batch is required.
166 if(kargs.batch_count < 1)
167 {
168 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
169 {
170 CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !");
171 }
172 return false;
173 }
// Consecutive A matrices must not overlap: the batch stride has to cover a
// full M x K matrix.
174 if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K)
175 {
176 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
177 {
179 "Conditions not met: batch_stride_A must be non-negative and at least K * M!");
180 }
181 return false;
182 }
// Same requirement for B: the batch stride has to cover a full K x N matrix.
183 if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N)
184 {
185 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
186 {
188 "Conditions not met: batch_stride_B must be non-negative and at least K * N!");
189 }
190 return false;
191 }
// Same requirement for the output E: the batch stride has to cover M x N.
192 if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N)
193 {
194 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
195 {
197 "Conditions not met: batch_stride_E must be non-negative and at least M * N!");
198 }
199 return false;
200 }
202 }
203
// Body of the device entry point; the signature (listing line 204) is missing
// from this extract. Per the cross-references it is
// operator()(BatchedGemmKernelArgs kargs) const.
// Work decomposition: blockIdx.x selects the output tile, blockIdx.y selects
// the batch and blockIdx.z selects the split-K slice.
// Index values are passed through amd_wave_read_first_lane (presumably to keep
// them uniform across the wave; confirm against that helper's documentation).
205 {
// Translate the linear block index into (M, N) tile indices, then into element
// offsets of the tile origin.
206 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
207 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
208 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
209
210 const auto i_batch = amd_wave_read_first_lane(blockIdx.y);
211 const auto i_splitk = amd_wave_read_first_lane(blockIdx.z);
212
213 const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
214
// Per-operand base pointers: start of tensor + batch offset (+ split-K offset
// for the A and B inputs).
// NOTE(review): the batch offsets are products of a batch index and an index_t
// stride. If that arithmetic is carried out in 32 bits it can overflow for
// large batched tensors. Confirm the width and widen if needed.
215 // options
216 const auto batch_stride_A = amd_wave_read_first_lane(kargs.batch_stride_A);
217 const auto batch_offset_A = amd_wave_read_first_lane(i_batch * batch_stride_A);
218 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + batch_offset_A +
219 splitk_batch_offset.as_k_split_offset[0];
220
221 const auto batch_stride_B = amd_wave_read_first_lane(kargs.batch_stride_B);
222 const auto batch_offset_B = amd_wave_read_first_lane(i_batch * batch_stride_B);
223 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + batch_offset_B +
224 splitk_batch_offset.bs_k_split_offset[0];
225
// The output pointer receives only the batch offset; no split-K offset is
// added here.
226 const auto batch_stride_E = amd_wave_read_first_lane(kargs.batch_stride_E);
227 const auto batch_offset_C = amd_wave_read_first_lane(i_batch * batch_stride_E);
228 CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr) + batch_offset_C;
229
230 // allocate LDS
231 __shared__ char smem_ptr[GetSmemSize()];
232
// NOTE(review): the line that opens this call (listing line 233) is missing
// from this extract; presumably it is UniversalGemmKernel::RunGemm(, whose
// parameter list in the cross-references matches these arguments.
234 {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
235 }
236};
237
238} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
The Batched GEMM kernel host arguments.
Definition batched_gemm_kernel.hpp:20
ck_tile::index_t batch_stride_B
Definition batched_gemm_kernel.hpp:55
ck_tile::index_t batch_stride_A
Definition batched_gemm_kernel.hpp:54
ck_tile::index_t batch_stride_E
Definition batched_gemm_kernel.hpp:56
CK_TILE_HOST BatchedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, ck_tile::index_t k_batch_, ck_tile::index_t M_, ck_tile::index_t N_, ck_tile::index_t K_, ck_tile::index_t stride_A_, ck_tile::index_t stride_B_, ck_tile::index_t stride_C_, ck_tile::index_t batch_stride_A_, ck_tile::index_t batch_stride_B_, ck_tile::index_t batch_stride_C_, ck_tile::index_t batch_count_)
Definition batched_gemm_kernel.hpp:21
ck_tile::index_t batch_count
Definition batched_gemm_kernel.hpp:57
ALayout and ADataType are expected to be scalars, not a tuple.
Definition batched_gemm_kernel.hpp:99
index_t batch_stride_E
Definition batched_gemm_kernel.hpp:102
index_t batch_count
Definition batched_gemm_kernel.hpp:103
index_t batch_stride_A
Definition batched_gemm_kernel.hpp:100
index_t batch_stride_B
Definition batched_gemm_kernel.hpp:101
Definition batched_gemm_kernel.hpp:62
static constexpr index_t kBlockSize
Definition batched_gemm_kernel.hpp:67
static CK_TILE_HOST auto IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs &kargs) -> bool
Definition batched_gemm_kernel.hpp:164
static CK_TILE_HOST constexpr BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs &hostArgs)
Definition batched_gemm_kernel.hpp:138
BatchedGemmKernelArgs KernelArgs
Definition batched_gemm_kernel.hpp:106
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
Definition batched_gemm_kernel.hpp:120
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition batched_gemm_kernel.hpp:70
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition batched_gemm_kernel.hpp:75
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition batched_gemm_kernel.hpp:69
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition batched_gemm_kernel.hpp:71
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition batched_gemm_kernel.hpp:76
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition batched_gemm_kernel.hpp:65
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition batched_gemm_kernel.hpp:158
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition batched_gemm_kernel.hpp:80
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
Definition batched_gemm_kernel.hpp:204
static CK_TILE_HOST auto BlockSize() -> dim3
Definition batched_gemm_kernel.hpp:125
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition batched_gemm_kernel.hpp:79
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition batched_gemm_kernel.hpp:74
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition batched_gemm_kernel.hpp:81
static CK_TILE_HOST auto GetName() -> const std::string
Definition batched_gemm_kernel.hpp:108
The Universal GEMM kernel host arguments.
Definition universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition universal_gemm_kernel.hpp:33
index_t K
Definition universal_gemm_kernel.hpp:70
void * e_ptr
Definition universal_gemm_kernel.hpp:65
index_t M
Definition universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:71
index_t N
Definition universal_gemm_kernel.hpp:69
index_t stride_E
Definition universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:61
index_t k_batch
Definition universal_gemm_kernel.hpp:80
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
The GEMM kernel device arguments.
Definition universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:94
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:88
index_t N
GEMM's N dimension size.
Definition universal_gemm_kernel.hpp:98
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:90
index_t M
GEMM's M dimension size.
Definition universal_gemm_kernel.hpp:96
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition universal_gemm_kernel.hpp:955
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145