blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MScaleBlock,
32 index_t NScaleBlock,
33 index_t KScaleBlock,
34 index_t MPerXDL,
35 index_t NPerXDL,
36 index_t MRepeat,
37 index_t NRepeat,
38 index_t KPacks>
42
43template <index_t BlockSize,
44 typename ADataType,
45 typename BDataType,
46 typename ComputeDataType,
47 typename AccDataType,
48 typename ATileDesc,
49 typename BTileDesc,
50 typename AMmaTileDesc,
51 typename BMmaTileDesc,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t MPerBlock,
55 index_t NPerBlock,
56 index_t KPerBlock,
57 index_t MScaleBlock,
58 index_t NScaleBlock,
59 index_t KScaleBlock,
60 index_t MPerXDL,
61 index_t NPerXDL,
62 index_t MRepeat,
63 index_t NRepeat,
64 index_t KPack
65 // ,bool TransposeC //disable transposec right now...
66 >
69 BlockSize,
70 ADataType,
71 BDataType,
72 ComputeDataType,
73 AccDataType,
74 ATileDesc,
75 BTileDesc,
76 AMmaTileDesc,
77 BMmaTileDesc,
78 ABlockTransferSrcScalarPerVector,
79 BBlockTransferSrcScalarPerVector,
80 MPerBlock,
81 NPerBlock,
82 KPerBlock,
83 MScaleBlock,
84 NScaleBlock,
85 KScaleBlock,
86 MPerXDL,
87 NPerXDL,
88 MRepeat,
89 NRepeat,
90 KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
91 ADataType,
92 BDataType,
93 ComputeDataType,
94 AccDataType,
95 ATileDesc,
96 BTileDesc,
97 AMmaTileDesc,
98 BMmaTileDesc,
99 ABlockTransferSrcScalarPerVector,
100 BBlockTransferSrcScalarPerVector,
101 MPerBlock,
102 NPerBlock,
103 KPerBlock,
104 MPerXDL,
105 NPerXDL,
106 MRepeat,
107 NRepeat,
108 KPack,
109 true>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack,
131 true>;
132 using Base::A_K1;
133 using Base::B_K1;
134 using Base::I0;
135 using Base::I1;
136 using Base::I2;
137 using Base::KGroup;
138 using Base::KRepeat;
139 using Base::xdlops_gemm;
140 using typename Base::HotLoopInstList;
141
154 using Base::MWaves;
155 using Base::WaveSize;
156
157 static constexpr index_t PrefetchStages = 2;
158 static constexpr index_t PrefillStages = 1;
159 static constexpr index_t GlobalBufferNum = 1;
160 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
161
162 template <typename TileDesc_M0_M1_M2_K>
163 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
164 {
165 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
166 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
167 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
168 constexpr index_t K2 = KPack / KGroup;
169 constexpr index_t K1 = WaveSize / NPerXDL;
170 constexpr index_t K0 = KRepeat * KGroup;
171
173 TileDesc_M0_M1_M2_K{},
181 }
182
183 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
185
186 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
187 {
188 return num_loop > PrefetchStages;
189 }
190
191 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
192 {
193 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
194 }
195
196 __device__ static constexpr auto HotLoopScheduler()
197 {
198 // A/B split schedule
199 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
200 constexpr auto num_ds_read_inst_a =
201 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
204
205 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
206
207 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
208 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
209
210 static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
211
212 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
213 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
214
215 constexpr auto ds_read_a_issue_cycle =
216 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
217 constexpr auto ds_read_a_mfma_rate =
218 math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
219
220 // constexpr auto num_dsread_a_mfma =
221 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
222
223 constexpr auto num_total_stages = MRepeat;
224
225 // Group num_mfma_perstage num_ds_read_a_perstage
226 // since we want to reuse a local register buffer
227 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
228 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
229
230 constexpr auto num_ds_read_a_mfma_perstage =
231 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
232
233 constexpr auto num_ds_read_a_prefetch_stages = 2;
234
235 constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
236 (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
237 constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
238 (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
239
240 constexpr auto buffer_load_stages_more =
241 (num_buffer_load_inst_a + num_buffer_load_inst_b) -
242 math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
243 (num_total_stages - 2)) *
244 ((num_total_stages - 2));
245
246 constexpr auto buffer_load_b_stages =
247 buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
248 ? num_buffer_load_inst_b / buffer_load_perstage_more
249 : (buffer_load_stages_more +
250 (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
251 buffer_load_perstage_less);
252
253 constexpr auto buffer_load_a_stages =
254 num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
255
256 constexpr auto buffer_load_issue_point_b = 0;
257 constexpr auto buffer_load_issue_point_interval_more =
258 num_mfma_perstage / buffer_load_perstage_more
259 ? num_mfma_perstage / buffer_load_perstage_more
260 : 1;
261 constexpr auto buffer_load_issue_point_interval_less =
262 num_mfma_perstage / buffer_load_perstage_less
263 ? num_mfma_perstage / buffer_load_perstage_less
264 : 1;
265 constexpr auto ds_write_issue_point = 0;
266 constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
267
268 // B global read
270 // Scale load, 1B
271 if constexpr(i.value == 0)
272 {
273 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
274 }
275 // Scale load, 1A
276 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
277
278 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
279 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
280
281 if constexpr(((i < buffer_load_stages_more) &&
282 (imfma % buffer_load_issue_point_interval_more ==
283 buffer_load_issue_point_b)) ||
284 ((i >= buffer_load_stages_more) &&
285 (imfma % buffer_load_issue_point_interval_less ==
286 buffer_load_issue_point_b)))
287 {
288 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
289 }
290
291 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
292 {
293 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
294 }
295 // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
296 });
297 // __builtin_amdgcn_sched_barrier(0);
298 });
299
300 // A global read + A local write
302 // Scale load, 1A
303 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
304 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
305 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
306 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
307 (imfma % buffer_load_issue_point_interval_more ==
308 ds_write_issue_point)) ||
309 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
310 (imfma % buffer_load_issue_point_interval_less ==
311 ds_write_issue_point)))
312 {
313 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
314 }
315 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
316 (imfma % buffer_load_issue_point_interval_more ==
317 buffer_load_issue_point_a)) ||
318 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
319 (imfma % buffer_load_issue_point_interval_less ==
320 buffer_load_issue_point_a)))
321 {
322 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
323 }
324 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
325 {
326 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
327 }
328 // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
329 });
330 // __builtin_amdgcn_sched_barrier(0);
331 });
332
333 // lds synchronization, prefetch next loop local A
335 ignore = i;
336 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
337 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
338 // Scale load, 1A
339 if constexpr(imfma == 0)
340 {
341 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
342 }
343
344 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
345 {
346 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
347 }
348 // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
349 });
350 // __builtin_amdgcn_sched_barrier(0);
351 });
352 }
353
354 template <bool HasMainLoop,
355 int NumKBlockPerScale,
356 TailNumber TailNum,
357 typename AGridDesc,
358 typename ABlockDesc,
359 typename ABlockTransfer,
360 typename AGridBuffer,
361 typename ABlockBuffer,
362 typename ABlockTransferStep,
363 typename BGridDesc,
364 typename BBlockDesc,
365 typename BBlockTransfer,
366 typename BGridBuffer,
367 typename BBlockBuffer,
368 typename BBlockTransferStep,
369 typename CScaleThreadDesc,
370 typename CThreadBuffer,
371 typename AScaleGridBuffer,
372 typename AScaleGridDesc,
373 typename AScaleThreadDesc,
374 typename AScaleThreadTransfer,
375 typename AScaleThreadTransferStep,
376 typename BScaleGridBuffer,
377 typename BScaleGridDesc,
378 typename BScaleThreadDesc,
379 typename BScaleThreadTransfer,
380 typename BScaleThreadTransferStep>
381 __device__ void Run(
382 // ABlockCopy
383 const AGridDesc& a_grid_desc,
384 const ABlockDesc& a_block_desc,
385 ABlockTransfer& a_blockwise_copy,
386 const AGridBuffer& a_grid_buf,
387 ABlockBuffer& a_block_buf,
388 const ABlockTransferStep& a_block_copy_step,
389 // BBlockCopy
390 const BGridDesc& b_grid_desc,
391 const BBlockDesc& b_block_desc,
392 BBlockTransfer& b_blockwise_copy,
393 BBlockTransfer& b_blockwise_copy_up,
394 const BGridBuffer& b_grid_buf,
395 const BGridBuffer& b_grid_buf_up,
396 BBlockBuffer& b_block_buf,
397 const BBlockTransferStep& b_block_copy_step,
398 // CThread
399 const CScaleThreadDesc& c_scale_thread_desc,
400 CThreadBuffer& c_thread_buf,
401 CThreadBuffer& c_thread_buf_up,
402 // AScaleThreadCopy
403 const AScaleGridDesc& a_scale_grid_desc,
404 const AScaleThreadDesc& a_scale_thread_desc,
405 AScaleThreadTransfer& a_scale_thread_copy,
406 const AScaleGridBuffer& a_scale_grid_buf,
407 const AScaleThreadTransferStep& a_scale_thread_copy_step,
408 // BScaleThreadCopy
409 const BScaleGridDesc& b_scale_grid_desc,
410 const BScaleThreadDesc& b_scale_thread_desc,
411 BScaleThreadTransfer& b_scale_thread_copy,
412 BScaleThreadTransfer& b_scale_thread_copy_up,
413 const BScaleGridBuffer& b_scale_grid_buf,
414 const BScaleGridBuffer& b_scale_grid_buf_up,
415 const BScaleThreadTransferStep& b_scale_thread_copy_step,
416 // num_loop
417 index_t num_loop) const
418 {
419 ignore = b_block_desc;
420 ignore = b_block_buf;
421 __builtin_amdgcn_sched_barrier(0);
422 static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1,
423 "Pipeline v3 only support scaleblocksliceK=1");
424 static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
425 "Pipeline v3 only support scaleblocksliceN=1");
426 // assume kperblock = scaleblockk
428 a_thread_desc_.GetElementSpaceSize());
430 b_thread_desc_.GetElementSpaceSize());
431
432 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
433 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
434 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
435
437 a_scale_thread_desc.GetElementSpaceSize());
439 b_scale_thread_desc.GetElementSpaceSize());
441 c_scale_thread_desc.GetElementSpaceSize());
443 c_scale_thread_desc.GetElementSpaceSize());
444
445 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
446 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
447 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;
448 // StaticallyIndexedArray<decltype(c_scale_thread_buf), Number<2>{}> c_scale_thread_bufs;
449
450 // Global prefetch A1 B1, AScale1 BScale1
451 b_blockwise_copy.Run(b_grid_desc,
452 b_grid_buf,
454 b_block_origin_idx,
455 b_thread_bufs(I0));
456
457 b_blockwise_copy_up.Run(b_grid_desc,
458 b_grid_buf_up,
460 b_block_origin_idx,
461 b_thread_bufs_up(I0));
462 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
463 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
464
465 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
466 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
467 __builtin_amdgcn_sched_barrier(0);
468
469 a_scale_thread_copy.Run(a_scale_grid_desc,
470 a_scale_grid_buf,
471 a_scale_thread_desc,
472 make_tuple(I0, I0),
473 a_scale_thread_bufs(I0));
474
475 if constexpr(NumKBlockPerScale == 1)
476 {
477 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
478 a_scale_thread_copy_step.At(Number<1>{}));
479 }
480 else
481 {
482 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
483 a_scale_thread_copy_step.At(Number<0>{}));
484 }
485
486 b_scale_thread_copy.Run(b_scale_grid_desc,
487 b_scale_grid_buf,
488 b_scale_thread_desc,
489 make_tuple(I0, I0),
490 b_scale_thread_bufs(I0));
491
492 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
493
494 b_scale_thread_copy_up.Run(b_scale_grid_desc,
495 b_scale_grid_buf_up,
496 b_scale_thread_desc,
497 make_tuple(I0, I0),
498 b_scale_thread_bufs_up(I0));
499
500 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
501
502 static_for<0, MRepeat, 1>{}([&](auto m0) {
503 c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
504 });
505 static_for<0, MRepeat, 1>{}([&](auto m0) {
506 c_scale_thread_buf_up(m0) =
507 a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
508 });
509
510 // Local prefill A1
511 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
512
513 // Global prefetch A2, AScale2 BScale2
514 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
515 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
516
517 a_scale_thread_copy.Run(a_scale_grid_desc,
518 a_scale_grid_buf,
519 a_scale_thread_desc,
520 make_tuple(I0, I0),
521 a_scale_thread_bufs(I0));
522
523 if constexpr(NumKBlockPerScale == 1)
524 {
525 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
526 a_scale_thread_copy_step.At(Number<1>{}));
527 }
528 else
529 {
530 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
531 a_scale_thread_copy_step.At(Number<0>{}));
532 }
533
534 b_scale_thread_copy.Run(b_scale_grid_desc,
535 b_scale_grid_buf,
536 b_scale_thread_desc,
537 make_tuple(I0, I0),
538 b_scale_thread_bufs(I0));
539
540 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
541
542 b_scale_thread_copy_up.Run(b_scale_grid_desc,
543 b_scale_grid_buf_up,
544 b_scale_thread_desc,
545 make_tuple(I0, I0),
546 b_scale_thread_bufs_up(I0));
547
548 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
549
550 // Initialize C
551 c_thread_buf.Clear();
552 c_thread_buf_up.Clear();
553
554 // Double register buffer for non-scaled gemm computation
555 // 1. Reduce register pressure
556 // 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
558 AccDataType,
559 1,
560 xdlops_gemm.GetRegSizePerXdlops(),
561 true>
562 c_thread_buf_per_scale;
564 AccDataType,
565 1,
566 xdlops_gemm.GetRegSizePerXdlops(),
567 true>
568 c_thread_buf_per_scale_up;
569
570 // Local prefetch A1
572 static_for<0, 2, 1>{}([&](auto m0) {
573 static_for<0, KRepeat, 1>{}([&](auto k0) {
574 static_for<0, KGroup, 1>{}([&](auto kg0) {
577 a_block_buf.At(I0),
579 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
580 a_thread_buf);
581 });
582 });
583 });
584
585 __builtin_amdgcn_sched_barrier(0);
586
587 // main body
588 if constexpr(HasMainLoop)
589 {
590 index_t i = 0;
591 do
592 {
593 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
594 b_blockwise_copy.Run(b_grid_desc,
595 b_grid_buf,
597 b_block_origin_idx,
598 b_thread_bufs(local_read_buf));
599 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
600 b_blockwise_copy_up.Run(b_grid_desc,
601 b_grid_buf_up,
603 b_block_origin_idx,
604 b_thread_bufs_up(local_read_buf));
605 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
606
607 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
608 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
609 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
610
611 a_scale_thread_copy.Run(a_scale_grid_desc,
612 a_scale_grid_buf,
613 a_scale_thread_desc,
614 make_tuple(I0, I0),
615 a_scale_thread_bufs(local_read_buf));
616
617 if constexpr(NumKBlockPerScale == 1)
618 {
619 a_scale_thread_copy.MoveSrcSliceWindow(
620 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
621 }
622 else
623 {
624 a_scale_thread_copy.MoveSrcSliceWindow(
625 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
626 }
627 b_scale_thread_copy.Run(b_scale_grid_desc,
628 b_scale_grid_buf,
629 b_scale_thread_desc,
630 make_tuple(I0, I0),
631 b_scale_thread_bufs(local_read_buf));
632
633 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
634 b_scale_thread_copy_step);
635
636 b_scale_thread_copy_up.Run(b_scale_grid_desc,
637 b_scale_grid_buf_up,
638 b_scale_thread_desc,
639 make_tuple(I0, I0),
640 b_scale_thread_bufs_up(local_read_buf));
641
642 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
643 b_scale_thread_copy_step);
644
645 static_for<0, MRepeat, 1>{}([&](auto m0) {
646 vector_type<AccDataType, 2> c_scale_thread_vec;
647 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
648 c_scale_thread_buf[m0];
649 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
650 c_scale_thread_buf[m0];
651 vector_type<AccDataType, 2> c_scale_thread_vec_up;
652 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
653 c_scale_thread_buf_up[m0];
654 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
655 c_scale_thread_buf_up[m0];
656
657 static_for<0, NRepeat, 1>{}([&](auto n0) {
658 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
659 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
660 .template AsType<AccDataType>()(Number<t>{}) = 0;
661 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
662 .template AsType<AccDataType>()(Number<t>{}) = 0;
663 });
664 static_for<0, KRepeat, 1>{}([&](auto k0) {
668
669 static_for<0, KPack, 1>{}([&](auto ik) {
670 a_thread_vec.template AsType<ComputeDataType>()(ik) =
671 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
672 make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
673 2,
674 I0,
675 I0,
676 k0,
677 I0,
678 ik))>{}];
679 b_thread_vec.template AsType<ComputeDataType>()(ik) =
680 b_thread_bufs[mfma_reg_buf]
681 [Number<b_thread_desc_.CalculateOffset(
682 make_tuple(n0, I0, k0, ik))>{}];
683
684 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
685 b_thread_bufs_up[mfma_reg_buf]
686 [Number<b_thread_desc_.CalculateOffset(
687 make_tuple(n0, I0, k0, ik))>{}];
688 });
689
690 using mfma_input_type =
691 typename vector_type<ComputeDataType,
692 xdlops_gemm.K1PerXdlops>::type;
693
694 xdlops_gemm.template Run<>(
695 a_thread_vec.template AsType<mfma_input_type>(),
696 b_thread_vec.template AsType<mfma_input_type>(),
697 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
698 xdlops_gemm.template Run<>(
699 a_thread_vec.template AsType<mfma_input_type>(),
700 b_thread_vec_up.template AsType<mfma_input_type>(),
701 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
702 });
703
704 constexpr index_t c_offset =
705 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
706
707 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
708 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
709
710 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
711 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
712 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
713 .template AsType<pk_fma_type>()[t],
714 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
715 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
716 .template AsType<pk_fma_type>()[t]);
717 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
718 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
719 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
720 .template AsType<pk_fma_type>()[t],
721 c_scale_thread_vec_up
722 .template AsType<pk_fma_type>()[Number<0>{}],
723 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
724 .template AsType<pk_fma_type>()[t]);
725 });
726 });
727
728 if constexpr(m0.value == (MRepeat - 2))
729 {
731
732 static_for<0, KRepeat, 1>{}([&](auto k0) {
733 static_for<0, KGroup, 1>{}([&](auto kg0) {
734 a_thread_copy_.Run(
736 make_tuple(Number<(m0 + 2) % MRepeat>{},
737 I0,
738 I0,
740 I0,
741 I0),
742 a_block_buf.At(local_read_buf),
745 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
746 2>{},
747 I0,
748 I0,
749 k0,
750 I0,
752 a_thread_buf);
753 });
754 });
755 }
756 else if constexpr(m0.value == (MRepeat - 1))
757 {
758 static_for<0, KRepeat, 1>{}([&](auto k0) {
759 static_for<0, KGroup, 1>{}([&](auto kg0) {
760 a_thread_copy_.Run(
762 make_tuple(Number<(m0 + 2) % MRepeat>{},
763 I0,
764 I0,
766 I0,
767 I0),
768 a_block_buf.At(local_read_buf),
771 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
772 2>{},
773 I0,
774 I0,
775 k0,
776 I0,
778 a_thread_buf);
779 });
780 });
781 }
782 else
783 {
784 static_for<0, KRepeat, 1>{}([&](auto k0) {
785 static_for<0, KGroup, 1>{}([&](auto kg0) {
786 a_thread_copy_.Run(
788 make_tuple(Number<(m0 + 2) % MRepeat>{},
789 I0,
790 I0,
792 I0,
793 I0),
794 a_block_buf.At(mfma_reg_buf),
797 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
798 2>{},
799 I0,
800 I0,
801 k0,
802 I0,
804 a_thread_buf);
805 });
806 });
807 }
808 });
809
810 static_for<0, MRepeat, 1>{}([&](auto m0) {
811 c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
812 b_scale_thread_bufs[mfma_reg_buf][I0];
813 c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
814 b_scale_thread_bufs_up[mfma_reg_buf][I0];
815 });
816
818 __builtin_amdgcn_sched_barrier(0);
819 };
820
821 LoopFunc(I0, I1);
822 LoopFunc(I1, I0);
823
824 i += 2;
825 } while(i < (num_loop - 2));
826 }
827
828 // tail
829 if constexpr(TailNum == TailNumber::Even)
830 {
831 b_blockwise_copy.Run(b_grid_desc,
832 b_grid_buf,
834 b_block_origin_idx,
835 b_thread_bufs(I1));
836 b_blockwise_copy_up.Run(b_grid_desc,
837 b_grid_buf_up,
839 b_block_origin_idx,
840 b_thread_bufs_up(I1));
841 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
842
843 static_for<0, MRepeat, 1>{}([&](auto m0) {
844 vector_type<AccDataType, 2> c_scale_thread_vec;
845 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
846 c_scale_thread_buf[m0];
847 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
848 c_scale_thread_buf[m0];
849 vector_type<AccDataType, 2> c_scale_thread_vec_up;
850 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
851 c_scale_thread_buf_up[m0];
852 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
853 c_scale_thread_buf_up[m0];
854
855 static_for<0, NRepeat, 1>{}([&](auto n0) {
856 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
857 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
858 .template AsType<AccDataType>()(Number<t>{}) = 0;
859 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
860 .template AsType<AccDataType>()(Number<t>{}) = 0;
861 });
862 static_for<0, KRepeat, 1>{}([&](auto k0) {
866
867 static_for<0, KPack, 1>{}([&](auto ik) {
868 a_thread_vec.template AsType<ComputeDataType>()(ik) =
869 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
870 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
871 b_thread_vec.template AsType<ComputeDataType>()(ik) =
872 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
873 make_tuple(n0, I0, k0, ik))>{}];
874 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
875 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
876 make_tuple(n0, I0, k0, ik))>{}];
877 });
878
879 using mfma_input_type =
880 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
881
882 xdlops_gemm.template Run<>(
883 a_thread_vec.template AsType<mfma_input_type>(),
884 b_thread_vec.template AsType<mfma_input_type>(),
885 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
886 xdlops_gemm.template Run<>(
887 a_thread_vec.template AsType<mfma_input_type>(),
888 b_thread_vec_up.template AsType<mfma_input_type>(),
889 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
890 });
891
892 constexpr index_t c_offset =
893 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
894
895 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
896 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
897
898 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
899 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
900 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
901 .template AsType<pk_fma_type>()[t],
902 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
903 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
904 .template AsType<pk_fma_type>()[t]);
905 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
906 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
907 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
908 .template AsType<pk_fma_type>()[t],
909 c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
910 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
911 .template AsType<pk_fma_type>()[t]);
912 });
913 });
914
915 if constexpr(m0.value == (MRepeat - 2))
916 {
918
919 static_for<0, KRepeat, 1>{}([&](auto k0) {
920 static_for<0, KGroup, 1>{}([&](auto kg0) {
921 a_thread_copy_.Run(
923 make_tuple(Number<(m0 + 2) % MRepeat>{},
924 I0,
925 I0,
927 I0,
928 I0),
929 a_block_buf.At(I1),
932 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
933 a_thread_buf);
934 });
935 });
936 }
937 else if constexpr(m0.value == (MRepeat - 1))
938 {
939 static_for<0, KRepeat, 1>{}([&](auto k0) {
940 static_for<0, KGroup, 1>{}([&](auto kg0) {
941 a_thread_copy_.Run(
943 make_tuple(Number<(m0 + 2) % MRepeat>{},
944 I0,
945 I0,
947 I0,
948 I0),
949 a_block_buf.At(I1),
952 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
953 a_thread_buf);
954 });
955 });
956 }
957 else
958 {
959 static_for<0, KRepeat, 1>{}([&](auto k0) {
960 static_for<0, KGroup, 1>{}([&](auto kg0) {
961 a_thread_copy_.Run(
963 make_tuple(Number<(m0 + 2) % MRepeat>{},
964 I0,
965 I0,
967 I0,
968 I0),
969 a_block_buf.At(I0),
972 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
973 a_thread_buf);
974 });
975 });
976 }
977 });
978
980
981 static_for<0, MRepeat, 1>{}([&](auto m0) {
982 c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
983 c_scale_thread_buf_up(m0) =
984 a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
985 });
986
987 static_for<0, MRepeat, 1>{}([&](auto m0) {
988 vector_type<AccDataType, 2> c_scale_thread_vec;
989 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
990 c_scale_thread_buf[m0];
991 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
992 c_scale_thread_buf[m0];
993 vector_type<AccDataType, 2> c_scale_thread_vec_up;
994 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
995 c_scale_thread_buf_up[m0];
996 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
997 c_scale_thread_buf_up[m0];
998
999 static_for<0, NRepeat, 1>{}([&](auto n0) {
1000 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
1001 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
1002 .template AsType<AccDataType>()(Number<t>{}) = 0;
1003 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
1004 .template AsType<AccDataType>()(Number<t>{}) = 0;
1005 });
1006 static_for<0, KRepeat, 1>{}([&](auto k0) {
1010
1011 static_for<0, KPack, 1>{}([&](auto ik) {
1012 a_thread_vec.template AsType<ComputeDataType>()(ik) =
1013 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
1014 (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
1015 b_thread_vec.template AsType<ComputeDataType>()(ik) =
1016 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
1017 make_tuple(n0, I0, k0, ik))>{}];
1018 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
1019 b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
1020 make_tuple(n0, I0, k0, ik))>{}];
1021 });
1022
1023 using mfma_input_type =
1024 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
1025
1026 xdlops_gemm.template Run<>(
1027 a_thread_vec.template AsType<mfma_input_type>(),
1028 b_thread_vec.template AsType<mfma_input_type>(),
1029 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
1030 xdlops_gemm.template Run<>(
1031 a_thread_vec.template AsType<mfma_input_type>(),
1032 b_thread_vec_up.template AsType<mfma_input_type>(),
1033 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
1034 });
1035 constexpr index_t c_offset =
1036 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1037
1038 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
1039 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
1040
1041 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1042 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
1043 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
1044 .template AsType<pk_fma_type>()[t],
1045 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
1046 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1047 .template AsType<pk_fma_type>()[t]);
1048
1049 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
1050 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
1051 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
1052 .template AsType<pk_fma_type>()[t],
1053 c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
1054 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
1055 .template AsType<pk_fma_type>()[t]);
1056 });
1057 });
1058
1059 if constexpr(m0.value < (MRepeat - 2))
1060 {
1061 static_for<0, KRepeat, 1>{}([&](auto k0) {
1062 static_for<0, KGroup, 1>{}([&](auto kg0) {
1063 a_thread_copy_.Run(
1065 make_tuple(
1067 a_block_buf.At(I1),
1069 make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
1070 I0,
1071 I0,
1072 k0,
1073 I0,
1075 a_thread_buf);
1076 });
1077 });
1078 }
1079 });
1080 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
1081 // latency
1082 // // __builtin_amdgcn_sched_barrier(0);
1083 }
1084 else
1085 {
1086 static_for<0, MRepeat, 1>{}([&](auto m0) {
1087 vector_type<AccDataType, 2> c_scale_thread_vec;
1088 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
1089 c_scale_thread_buf[m0];
1090 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
1091 c_scale_thread_buf[m0];
1092 vector_type<AccDataType, 2> c_scale_thread_vec_up;
1093 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
1094 c_scale_thread_buf_up[m0];
1095 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
1096 c_scale_thread_buf_up[m0];
1097
1098 static_for<0, NRepeat, 1>{}([&](auto n0) {
1099 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
1100 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
1101 .template AsType<AccDataType>()(Number<t>{}) = 0;
1102 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
1103 .template AsType<AccDataType>()(Number<t>{}) = 0;
1104 });
1105 static_for<0, KRepeat, 1>{}([&](auto k0) {
1109
1110 static_for<0, KPack, 1>{}([&](auto ik) {
1111 a_thread_vec.template AsType<ComputeDataType>()(ik) =
1112 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1113 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
1114 b_thread_vec.template AsType<ComputeDataType>()(ik) =
1115 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1116 make_tuple(n0, I0, k0, ik))>{}];
1117 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
1118 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
1119 make_tuple(n0, I0, k0, ik))>{}];
1120 });
1121
1122 using mfma_input_type =
1123 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
1124
1125 xdlops_gemm.template Run<>(
1126 a_thread_vec.template AsType<mfma_input_type>(),
1127 b_thread_vec.template AsType<mfma_input_type>(),
1128 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
1129 xdlops_gemm.template Run<>(
1130 a_thread_vec.template AsType<mfma_input_type>(),
1131 b_thread_vec_up.template AsType<mfma_input_type>(),
1132 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
1133 });
1134 constexpr index_t c_offset =
1135 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1136
1137 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
1138 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
1139
1140 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1141 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
1142 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
1143 .template AsType<pk_fma_type>()[t],
1144 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
1145 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1146 .template AsType<pk_fma_type>()[t]);
1147 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
1148 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
1149 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
1150 .template AsType<pk_fma_type>()[t],
1151 c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
1152 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
1153 .template AsType<pk_fma_type>()[t]);
1154 });
1155 });
1156
1157 if constexpr(m0.value < (MRepeat - 2))
1158 {
1159 static_for<0, KRepeat, 1>{}([&](auto k0) {
1160 static_for<0, KGroup, 1>{}([&](auto kg0) {
1161 a_thread_copy_.Run(
1163 make_tuple(
1165 a_block_buf.At(I0),
1167 make_tuple(
1168 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
1169 a_thread_buf);
1170 });
1171 });
1172 }
1173 });
1174 }
1175 }
1176
1177 protected:
1178 // MRepeat MWave MLane KRepeat KLane KPack
1179 // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
1180 // Reduce VGPR usage here.
1183
1185 ComputeDataType,
1187 decltype(a_thread_desc_),
1188 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
1190 5,
1191 A_K1,
1192 A_K1>;
1193
1195
1198
1199 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
1200
1202};
1203
1204} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, BScaleThreadTransfer &b_scale_thread_copy_up, const BScaleGridBuffer &b_scale_grid_buf, const BScaleGridBuffer &b_scale_grid_buf_up, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp:381
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp:112
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp:1184
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp:40
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10