gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp Source File

gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp Source File

Composable Kernel: gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp Source File
gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7#include <sstream>
8
9#include "ck_tile/core.hpp"
15
16namespace ck_tile {
17
18template <typename Problem, typename PipelinePolicy = GemmWPQuantPipelineAgBgCrPolicy>
20{
29
34
36 decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant<Problem>())>;
37
// Warp-level GEMM configuration queried from the block GEMM's policy
// (GetWarpGemmMWarpNWarp<Problem>). Element 0 is used as WG, the warp GEMM type whose
// kM and AWarpDstrEncoding set up the per-warp A LDS windows in operator().
38 static constexpr auto config =
39 BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
40
41 using WG = remove_cvref_t<decltype(config.template at<0>())>;
42
// Tile sizes, padding flags, index constants, warp counts and m_preload are taken
// from the weight-preshuffle base pipeline (Base).
43 using Base::kKPerBlock;
44 using Base::kMPerBlock;
45 using Base::kNPerBlock;
46
50
51 using Base::BlockSize;
52
53 using Base::kPadK;
54 using Base::kPadM;
55 using Base::kPadN;
56
57 using Base::I0;
58 using Base::I1;
59 using Base::I2;
60
61 using Base::MWarp;
62 using Base::NWarp;
63
66
69
70 using Base::m_preload;
71
// Number of BQ scale entries along K per block K tile:
// ceil(BlockGemmShape::kK / QuantGroupSize::kK). operator() advances the BQ DRAM
// window by this amount after each BQ load.
72 static constexpr index_t KPerBlockBQ =
73 integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
// Same ceil-division, but using the inherited kKPerBlock.
// NOTE(review): not referenced in the visible code; it equals KPerBlockBQ if
// Base::kKPerBlock == BlockGemmShape::kK — confirm and consider dropping the duplicate.
74 static constexpr index_t QScalesPerBlockRow =
75 integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
76
77 static constexpr index_t GetVectorSizeBQ()
78 {
79 return PipelinePolicy::template GetVectorSizeBQ<Problem>();
80 }
82
// Returns a descriptive name for this pipeline configuration. The visible parts are
// joined with '_': the fixed prefix "bquant_pipeline_AgBgCrV2_preshuffleB",
// BlockGemmShape::BlockWarps::at(I0) and at(I1) (WaveNumM, WaveNumN) joined with 'x',
// the kPadM/kPadN/kPadK flags joined with 'x', and QuantGroupSize::GetName().
// NOTE(review): source lines 89, 90 and 92 are missing from this listing, so the name
// may contain further components between those shown.
// NOTE(review): returning `const std::string` by value prevents callers from moving
// from the result; a plain `std::string` return type would be preferable.
83 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
84 {
85 // clang-format off
86 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
87 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
88 return concat('_', "bquant_pipeline_AgBgCrV2_preshuffleB",
91 concat('x', WaveNumM, WaveNumN),
93 concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
94 // clang-format on
95 }
96
// Compile-time settings forwarded from Problem.
// PreshuffleB is not referenced in the visible code (presumably read by callers).
// TailNum is the tail kind used by the operator() overload that takes no runtime
// tail_number.
// NOTE(review): the main operator() template also names its template parameter
// `TailNum`, which shadows this member inside that overload; a distinct name would
// avoid confusion.
97 static constexpr bool PreshuffleB = Problem::PreshuffleB;
98 static constexpr auto TailNum = Problem::TailNum;
99
// Core B-quant weight-preshuffle GEMM pipeline for one output block tile.
// For each K block:
//   - A is loaded from DRAM, passed through a_element_func, and stored to one of two
//     LDS buffers (ping = p_smem_ping, pong = p_smem_pong).
//   - Per-warp A sub-tiles are read back from LDS into a_warp_tensor.
//   - Flat (preshuffled) B tiles are read from b_flat_dram_windows into
//     b_warp_tensor_ping / b_warp_tensor_pong. The load call is on a line missing from
//     this listing; the tooltip list suggests load_int4_tile.
//   - BQ scales are loaded into bq_block_tile / bq_block_tile_2.
//   - The ping and pong sets alternate between consecutive K blocks, and each K block is
//     accumulated into c_block_tile by block_weight_preshuffle.
//
// Parameters:
//   TailNum: compile-time tail kind. The main loop runs (num_loop - 1) / 2 times, two K
//     blocks per iteration; the tail then handles the remaining two (Even) or one (Odd)
//     K blocks. Presumably TailNum must match the parity of num_loop (as selected via
//     Base::TailHandler) — confirm.
//   a_element_func: applied element-wise to every A DRAM tile before it is stored to LDS.
//   num_loop: number of K blocks to accumulate; the A window advances kKPerBlock per block.
//   p_smem_ping / p_smem_pong: LDS buffers for A, each reinterpreted as ADataType*.
//   UnaryOpSize_: not referenced in the visible body.
// Returns the accumulated C tile (c_block_tile).
//
// NOTE(review): this listing is a documentation extract with some source lines missing
// (e.g. 149, 215, 257, 289, 346, 388, 428). Calls made on those lines (window lengths,
// the B loads, any LDS syncs / schedulers) are not shown here.
100 template <TailNumber TailNum,
101 typename ADramBlockWindowTmp,
102 typename BFlatBlockWindowTmp,
103 typename BQDramBlockWindowTmp,
104 typename AElementFunction,
105 index_t UnaryOpSize_ = 8>
106 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
107 const AElementFunction& a_element_func,
108 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
109 const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
110 index_t num_loop,
111 void* p_smem_ping,
112 void* p_smem_pong) const
113 {
114 static_assert(
115 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
116 std::is_same_v<BDataType, remove_cvref_t<typename BFlatBlockWindowTmp::DataType>> &&
117 std::is_same_v<BQDataType, remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
118 "A/B/BQ Dram block window should have the same data type as appropriate "
119 "([A|B|BQ]DataType) defined in Problem definition!");
120
121 constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
122 static_assert(!is_a_col_major, "A must be row major (col major not supported yet)");
123
124 constexpr bool is_bq_col_major = std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
125 static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
126
127 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
128 static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
129
// This warp's row index along M (assumes warp ids are laid out N-fastest, NWarp per row).
130 const index_t iMWarp = get_warp_id() / NWarp;
131
132 __builtin_amdgcn_sched_barrier(0);
133
134 // A tile in LDS
135 ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
136 ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
137
138 constexpr auto a_lds_block_desc =
139 PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
140
141 auto a_lds_block_ping =
142 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
143 auto a_lds_block_pong =
144 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
145
146 // A DRAM tile window for load
147 auto a_copy_dram_window =
148 make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
150 a_dram_block_window_tmp.get_window_origin(),
151 PipelinePolicy::template MakeADramTileDistribution<Problem>());
152
153 auto a_copy_lds_window_ping =
154 make_tile_window(a_lds_block_ping,
156 {0, 0},
157 PipelinePolicy::template MakeADramTileDistribution<Problem>());
158
159 auto a_copy_lds_window_pong =
160 make_tile_window(a_lds_block_pong,
162 {0, 0},
163 PipelinePolicy::template MakeADramTileDistribution<Problem>());
164
165 // ping-pong window for A LDS
166 auto a_warp_window_ping_tmp =
167 make_tile_window(a_lds_block_ping,
169 {iMWarp * WG::kM, 0},
170 make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
171
172 auto a_warp_window_pong_tmp =
173 make_tile_window(a_lds_block_pong,
175 {iMWarp * WG::kM, 0},
176 make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
177
179 statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
181 a_warp_windows_ping;
182
184 statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
186 a_warp_windows_pong;
187
// One A warp window per (mIter, kIter), offset by MPerBlockPerIter / KPerBlockPerIter.
188 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
189 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
190 a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
191
192 move_tile_window(a_warp_windows_ping(mIter)(kIter),
193 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
194 });
195 });
196
197 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
198 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
199 a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
200
201 move_tile_window(a_warp_windows_pong(mIter)(kIter),
202 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
203 });
204 });
205
206 // Block GEMM
207 auto block_weight_preshuffle = BlockWeightPreshuffle();
208 // Acc register tile
209 auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
210
211 // B flat DRAM window for load
212 auto b_flat_distribution =
213 PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
214 auto b_flat_dram_window = // tile_window_with_static_distribution
216 b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
218 b_flat_dram_block_window_tmp.get_window_origin(),
219 b_flat_distribution);
220
// pk_int4_t B is held in tiles of ADataType; any other B type is held as-is.
221 using BTypeToUse =
222 std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
223 using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
224
225 // pingpong buffer for B
227 statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
229 b_flat_dram_windows;
230
232 b_warp_tensor_ping;
233
235 b_warp_tensor_pong;
236
237 // BQ DRAM window for load
238 auto bq_copy_dram_window =
239 make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
241 bq_dram_block_window_tmp.get_window_origin(),
242 PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
243
244 // Prefetch A0
245 auto a_block_tile = load_tile(a_copy_dram_window);
246 // move A window to next k
247 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
248
249 // prefetch B
250 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
251 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
252 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
253
254 move_tile_window(b_flat_dram_windows(nIter)(kIter),
255 {nIter * flatNPerWarp, kIter * flatKPerWarp});
256
258 b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
259 });
260 });
261 // move B window to next flat K
262 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
263
264 // Strictly not needed given type deduction, but helps with readability
265 using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
266 using BQBlockTile =
267 decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
268
// bq_block_tile is consumed together with the ping buffers, bq_block_tile_2 with the
// pong buffers (see the block_weight_preshuffle calls below).
269 // Load tile 0 for BQ data directly into registers for block tile
270 BQBlockTile bq_block_tile, bq_block_tile_2;
271 bq_block_tile = load_tile(bq_copy_dram_window);
272 // move BQ to tile 1
273 move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
274
275 // Prefill A0
276 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
277 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
278
279 __builtin_amdgcn_sched_barrier(0);
280
// NOTE(review): when num_loop == 1 this A block lies past the last K block and is never
// consumed; presumably safe because DRAM loads are bounds-checked — confirm.
281 // Prefetch A1
282 a_block_tile = load_tile(a_copy_dram_window);
283 // move A window to next k
284 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
285
286 // initialize C
287 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
288
290
291 // preload A00,A10 from lds
292 statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
293 m_preload>
294 a_warp_tensor;
295
// Preload index loadIter maps to (mIter, kIter) with mIter varying fastest.
296 static_for<0, m_preload, 1>{}([&](auto loadIter) {
297 constexpr auto mIter = loadIter % MIterPerWarp;
298 constexpr auto kIter = loadIter / MIterPerWarp;
299 a_warp_tensor(loadIter) =
300 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
301 });
302 __builtin_amdgcn_sched_barrier(0);
303
// Each iteration consumes two K blocks: block 2i with the ping buffers and block 2i+1
// with the pong buffers, while loading the next blocks into the opposite set.
// (num_loop - 1) / 2 iterations leave 1 block (odd num_loop) or 2 blocks (even num_loop)
// for the tail below.
304 // MAIN LOOP
305 index_t iCounter = (num_loop - 1) / 2;
306 while(iCounter > 0)
307 {
308 // prefetch B(2i+1)
309 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
310 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
311 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
312
313 move_tile_window(b_flat_dram_windows(nIter)(kIter),
314 {nIter * flatNPerWarp, kIter * flatKPerWarp});
316 b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
317 });
318 });
319 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
320
321 bq_block_tile_2 = load_tile(bq_copy_dram_window);
322 move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
323
324 // Prefill A(2i+1)
325 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
326 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
327
328 // Prefetch A(2i+2)
329 a_block_tile = load_tile(a_copy_dram_window);
330 // move A window to next k
331 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
332
333 // GEMM 2i
334 block_weight_preshuffle(c_block_tile,
335 a_warp_tensor,
336 b_warp_tensor_ping,
337 bq_block_tile,
338 a_warp_windows_ping);
339
340 static_for<0, m_preload, 1>{}([&](auto loadIter) {
341 constexpr auto mIter = loadIter % MIterPerWarp;
342 constexpr auto kIter = loadIter / MIterPerWarp;
343 a_warp_tensor(loadIter) =
344 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
345 });
347
348 // Next K
349
350 // prefetch B(2i+2)
351 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
352 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
353 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
354
355 move_tile_window(b_flat_dram_windows(nIter)(kIter),
356 {nIter * flatNPerWarp, kIter * flatKPerWarp});
358 b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
359 });
360 });
361 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
362
363 bq_block_tile = load_tile(bq_copy_dram_window);
364 move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
365
366 // Prefill A(2i+2)
367 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
368 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
369
// NOTE(review): on the final iteration with an Odd tail this fetches A block num_loop,
// one past the last K block, and the result is never consumed; presumably safe because
// DRAM loads are bounds-checked — confirm.
370 // Prefetch A(2i+3)
371 a_block_tile = load_tile(a_copy_dram_window);
372 // move A window to next k
373 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
374
375 // GEMM 2i+1
376 block_weight_preshuffle(c_block_tile,
377 a_warp_tensor,
378 b_warp_tensor_pong,
379 bq_block_tile_2,
380 a_warp_windows_pong);
381
382 static_for<0, m_preload, 1>{}([&](auto loadIter) {
383 constexpr auto mIter = loadIter % MIterPerWarp;
384 constexpr auto kIter = loadIter / MIterPerWarp;
385 a_warp_tensor(loadIter) =
386 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
387 });
389
390 iCounter--;
391 }
392
393 // tail
394 if constexpr(TailNum == TailNumber::Even)
395 {
// Two K blocks remain: the one already resident in the ping buffers, then the last
// one, which is loaded here into the pong buffers.
396 // prefetch B(loopK)
397 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
398 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
399 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
400
401 move_tile_window(b_flat_dram_windows(nIter)(kIter),
402 {nIter * flatNPerWarp, kIter * flatKPerWarp});
403
405 b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
406 });
407 });
408 bq_block_tile_2 = load_tile(bq_copy_dram_window);
409
410 // Prefill A(loopK)
411 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
412 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
413
414 // GEMM loopK-1
415 block_weight_preshuffle(c_block_tile,
416 a_warp_tensor,
417 b_warp_tensor_ping,
418 bq_block_tile,
419 a_warp_windows_ping);
420
421 static_for<0, m_preload, 1>{}([&](auto loadIter) {
422 constexpr auto mIter = loadIter % MIterPerWarp;
423 constexpr auto kIter = loadIter / MIterPerWarp;
424 a_warp_tensor(loadIter) =
425 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
426 });
427
429
430 // GEMM loopK
431 block_weight_preshuffle(c_block_tile,
432 a_warp_tensor,
433 b_warp_tensor_pong,
434 bq_block_tile_2,
435 a_warp_windows_pong);
437 }
438 else if constexpr(TailNum == TailNumber::Odd)
439 {
// One K block remains, already resident in the ping buffers.
440 // GEMM loopK
441 block_weight_preshuffle(c_block_tile,
442 a_warp_tensor,
443 b_warp_tensor_ping,
444 bq_block_tile,
445 a_warp_windows_ping);
447 }
448
449 return c_block_tile;
450 }
451
452 template <typename ADramBlockWindowTmp,
453 typename BFlatBlockWindowTmp,
454 typename BQDramBlockWindowTmp>
455 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
456 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
457 const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
458 index_t num_loop,
459 void* p_smem_ping,
460 void* p_smem_pong) const
461 {
462
463 return operator()<TailNum>(
464 a_dram_block_window_tmp,
465 [](const ADataType& a) { return a; },
466 b_flat_dram_block_window_tmp,
467 bq_dram_block_window_tmp,
468 num_loop,
469 p_smem_ping,
470 p_smem_pong);
471 }
472
473 template <typename ADramBlockWindowTmp,
474 typename BFlatBlockWindowTmp,
475 typename BQDramBlockWindowTmp>
476 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
477 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
478 const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
479 index_t num_loop,
480 TailNumber tail_number,
481 void* p_smem_ping,
482 void* p_smem_pong) const
483 {
484 const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
485 (void)bool_val; // Suppress unused parameter warning
486 constexpr auto tail_num = tail_num_.value;
487 return operator()<tail_num>(
488 a_dram_block_window_tmp,
489 [](const ADataType& a) { return a; },
490 b_flat_dram_block_window_tmp,
491 bq_dram_block_window_tmp,
492 num_loop,
493 p_smem_ping,
494 p_smem_pong);
495 };
496 return Base::TailHandler(RunPipeline, true, tail_number);
497 }
498};
499} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE void load_int4_tile(WarpTile &dst, const WarpWindow &src)
Definition load_interleaved_pk_type.hpp:46
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool, TailNumber tail_number)
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:35
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:20
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:33
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:28
static constexpr bool PreshuffleB
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:97
static constexpr index_t KPerBlockBQ
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:72
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:26
static constexpr index_t GetVectorSizeBQ()
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:77
static constexpr index_t NWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:135
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:41
static constexpr bool kPadN
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:118
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp &bq_dram_block_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:106
static constexpr bool kPadK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:119
static constexpr index_t MIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:137
WeightPreshufflePipelineAGmemBGmemCRegV2< Problem > Base
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:21
static constexpr index_t m_preload
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:148
static constexpr index_t NIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:138
static constexpr bool kPadM
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:117
remove_cvref_t< typename Problem::ALayout > ALayout
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:30
remove_cvref_t< typename Problem::ADataType > ADataType
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:22
static CK_TILE_HOST const std::string GetName()
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:83
static constexpr auto TailNum
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:98
static constexpr index_t KIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:139
remove_cvref_t< typename Problem::BQLayout > BQLayout
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:32
remove_cvref_t< typename Problem::BLayout > BLayout
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:31
remove_cvref_t< typename Problem::BDataType > BDataType
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:23
static constexpr index_t flatKPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:98
static constexpr index_t KIterPerQScale
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:81
static constexpr auto I0
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:124
remove_cvref_t< typename Problem::BQDataType > BQDataType
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:24
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:25
static constexpr index_t kNPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:90
static constexpr index_t flatNPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:99
static constexpr index_t BlockSize
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:86
static constexpr index_t KPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:145
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp &bq_dram_block_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:455
static constexpr index_t kKPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:91
static constexpr index_t MPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:144
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:27
static constexpr auto config
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:38
static constexpr index_t kMPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:89
remove_cvref_t< decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant< Problem >())> BlockWeightPreshuffle
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:35
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp &bq_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, void *p_smem_ping, void *p_smem_pong) const
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:476
static constexpr index_t QScalesPerBlockRow
Definition gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp:74
static constexpr auto I1
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:125
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:54
static constexpr index_t GetVectorSizeB()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:107
static constexpr index_t GetVectorSizeA()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:102
static CK_TILE_HOST_DEVICE constexpr auto HotLoopScheduler()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:295
static constexpr index_t NWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:135
static constexpr bool kPadN
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:118
static constexpr index_t MWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:134
static constexpr bool kPadK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:119
static constexpr index_t MIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:137
static constexpr index_t m_preload
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:148
static constexpr index_t NIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:138
static constexpr bool kPadM
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:117
static constexpr index_t KIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:139
static constexpr index_t flatKPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:98
static constexpr auto I0
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:124
static constexpr index_t kNPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:90
static constexpr index_t flatNPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:99
static constexpr index_t BlockSize
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:86
static CK_TILE_HOST_DEVICE constexpr auto LastHotLoopScheduler()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:490
static constexpr auto I2
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:126
static constexpr index_t KPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:145
static constexpr index_t kKPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:91
static constexpr index_t MPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:144
static constexpr index_t kMPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:89
static CK_TILE_HOST_DEVICE constexpr auto Last2ndHotLoopScheduler()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:436
static constexpr auto I1
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:125
Definition tile/core/utility/functional.hpp:43