device_gemm_bias_add_reduce_xdl_cshuffle.hpp Source File

device_gemm_bias_add_reduce_xdl_cshuffle.hpp Source File#

Composable Kernel: device_gemm_bias_add_reduce_xdl_cshuffle.hpp Source File
device_gemm_bias_add_reduce_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24// non-c-shuffle version currently has compiler issues with register spilling, which in turn
25// cause validation failures.
26template <typename ALayout,
27 typename BLayout,
28 typename CLayout,
29 typename ADataType,
30 typename BDataType,
31 typename CDataType,
32 typename BiasDataType,
33 typename D0DataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename ReduceAccDataType,
37 typename ReducePtrsGlobal,
38 typename AElementwiseOperation,
39 typename BElementwiseOperation,
40 typename CElementwiseOperation,
41 typename D0ElementwiseOperation,
42 typename ReduceOperations,
43 typename ReduceInElementwiseOperations,
44 typename ReduceAccElementwiseOperations,
45 typename ReduceGlobalMemoryDataOperation,
46 GemmSpecialization GemmSpec,
47 index_t NumGemmKPrefetchStage,
48 index_t BlockSize,
49 index_t MPerBlock,
50 index_t NPerBlock,
51 index_t KPerBlock,
52 index_t AK1,
53 index_t BK1,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MXdlPerWave,
57 index_t NXdlPerWave,
58 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
59 typename ABlockTransferThreadClusterArrangeOrder,
60 typename ABlockTransferSrcAccessOrder,
61 index_t ABlockTransferSrcVectorDim,
62 index_t ABlockTransferSrcScalarPerVector,
63 index_t ABlockTransferDstScalarPerVector_AK1,
64 bool ABlockLdsExtraM,
65 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
66 typename BBlockTransferThreadClusterArrangeOrder,
67 typename BBlockTransferSrcAccessOrder,
68 index_t BBlockTransferSrcVectorDim,
69 index_t BBlockTransferSrcScalarPerVector,
70 index_t BBlockTransferDstScalarPerVector_BK1,
71 bool BBlockLdsExtraN,
72 index_t CShuffleMXdlPerWavePerShuffle,
73 index_t CShuffleNXdlPerWavePerShuffle,
74 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
75 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
76 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
77 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
78 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
80struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
81{
83
85 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
86 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
87
88 static constexpr auto I0 = Number<0>{};
89 static constexpr auto I1 = Number<1>{};
90 static constexpr auto I2 = Number<2>{};
91
92 static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
93 {
94 const auto a_grid_desc_mraw_kraw = [&]() {
96 {
98 make_tuple(StrideA, I1));
99 }
101 {
102 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
103 make_tuple(I1, StrideA));
104 }
105 }();
106
107 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
108 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
109
110 const auto MPad = M - MRaw;
111 const auto KPad = K - KRaw;
112
113 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
115 {
116 // pad both M and K
117 assert(K % AK1 == 0);
118
119 const auto AK0 = K / AK1;
120
121 const auto a_grid_desc_m_k =
122 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
124 make_right_pad_transform(KRaw, KPad)),
127
128 const auto a_grid_desc_ak0_m_ak1 =
129 transform_tensor_descriptor(a_grid_desc_m_k,
134
135 return a_grid_desc_ak0_m_ak1;
136 }
137 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
139 {
140 // pad M, but not K
141 assert(KRaw % AK1 == 0);
142
143 const auto AK0 = KRaw / AK1;
144
145 const auto a_grid_desc_ak0_m_ak1 =
146 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
148 make_right_pad_transform(MRaw, MPad)),
151
152 return a_grid_desc_ak0_m_ak1;
153 }
154 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
156 {
157 // pad K, but not M
158 assert(K % AK1 == 0);
159
160 const auto AK0 = K / AK1;
161
162 const auto a_grid_desc_m_k = transform_tensor_descriptor(
163 a_grid_desc_mraw_kraw,
167
168 const auto a_grid_desc_ak0_m_ak1 =
169 transform_tensor_descriptor(a_grid_desc_m_k,
174
175 return a_grid_desc_ak0_m_ak1;
176 }
177 else
178 {
179 // not pad M or K
180 assert(KRaw % AK1 == 0);
181
182 const auto AK0 = KRaw / AK1;
183
184 const auto a_grid_desc_ak0_m_ak1 =
185 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
190
191 return a_grid_desc_ak0_m_ak1;
192 }
193 }
194
195 static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
196 {
197 const auto b_grid_desc_nraw_kraw = [&]() {
199 {
200 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
201 make_tuple(I1, StrideB));
202 }
204 {
205 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
206 make_tuple(StrideB, I1));
207 }
208 }();
209
210 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
211 const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
212
213 const auto NPad = N - NRaw;
214 const auto KPad = K - KRaw;
215
216 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
218 {
219 // pad both N and K
220 assert(K % BK1 == 0);
221
222 const auto BK0 = K / BK1;
223
224 const auto b_grid_desc_n_k =
225 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
227 make_right_pad_transform(KRaw, KPad)),
230
231 const auto b_grid_desc_bk0_n_bk1 =
232 transform_tensor_descriptor(b_grid_desc_n_k,
237
238 return b_grid_desc_bk0_n_bk1;
239 }
240 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
242 {
243 // pad N, but not K
244 assert(KRaw % BK1 == 0);
245
246 const auto BK0 = KRaw / BK1;
247
248 const auto b_grid_desc_bk0_n_bk1 =
249 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
251 make_right_pad_transform(NRaw, NPad)),
254
255 return b_grid_desc_bk0_n_bk1;
256 }
257 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
259 {
260 // pad K, but not N
261 assert(K % BK1 == 0);
262
263 const auto BK0 = K / BK1;
264
265 const auto b_grid_desc_n_k = transform_tensor_descriptor(
266 b_grid_desc_nraw_kraw,
270
271 const auto b_grid_desc_bk0_n_bk1 =
272 transform_tensor_descriptor(b_grid_desc_n_k,
277
278 return b_grid_desc_bk0_n_bk1;
279 }
280 else
281 {
282 // not pad N or K
283 assert(KRaw % BK1 == 0);
284
285 const auto BK0 = KRaw / BK1;
286
287 const auto b_grid_desc_bk0_n_bk1 =
288 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
293
294 return b_grid_desc_bk0_n_bk1;
295 }
296 }
297
298 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
299 {
300 const auto c_grid_desc_mraw_nraw = [&]() {
302 {
303 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
304 make_tuple(StrideC, I1));
305 }
307 {
308 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
309 make_tuple(I1, StrideC));
310 }
311 }();
312
313 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
314 const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
315
316 const auto MPad = M - MRaw;
317 const auto NPad = N - NRaw;
318
319 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
321 {
322 // pad M and N
323 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
325 make_right_pad_transform(NRaw, NPad)),
328 }
329 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
331 {
332 // pad M, but not N
334 c_grid_desc_mraw_nraw,
338 }
339 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
341 {
342 // pad N, but not M
344 c_grid_desc_mraw_nraw,
348 }
349 else
350 {
351 // not pad M or N
352 return c_grid_desc_mraw_nraw;
353 }
354 }
355
356 // assume D is packed tensor
358 {
359 const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
360
361 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
362 const auto MPad = M - MRaw;
363
364 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
365 GemmSpec == GemmSpecialization::MNPadding ||
366 GemmSpec == GemmSpecialization::MKPadding ||
368 {
369 // pad M
370 return transform_tensor_descriptor(d_grid_desc_mraw,
374 }
375 else
376 {
377 // not pad M
378 return d_grid_desc_mraw;
379 }
380 }
381
384 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
385 using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
386 using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
388
389 // GridwiseGemm
390 template <index_t NXdlPerWave_>
392 ADataType, // TODO: distinguish A/B datatype
393 GemmAccDataType,
394 CShuffleDataType,
395 CDataType,
396 BiasDataType,
397 D0DataType,
398 ReduceAccDataType,
399 ReducePtrsGlobal,
400 AElementwiseOperation,
401 BElementwiseOperation,
402 CElementwiseOperation,
403 D0ElementwiseOperation,
404 ReduceOperations,
405 ReduceInElementwiseOperations,
406 ReduceAccElementwiseOperations,
408 ReduceGlobalMemoryDataOperation,
415 NumGemmKPrefetchStage,
416 BlockSize,
417 MPerBlock,
418 NPerBlock,
419 KPerBlock,
420 AK1,
421 BK1,
422 MPerXDL,
423 NPerXDL,
424 MXdlPerWave,
425 NXdlPerWave_,
426 ABlockTransferThreadClusterLengths_AK0_M_AK1,
427 ABlockTransferThreadClusterArrangeOrder,
428 ABlockTransferSrcAccessOrder,
429 ABlockTransferSrcVectorDim,
430 ABlockTransferSrcScalarPerVector,
431 ABlockTransferDstScalarPerVector_AK1,
432 false,
433 ABlockLdsExtraM,
434 BBlockTransferThreadClusterLengths_BK0_N_BK1,
435 BBlockTransferThreadClusterArrangeOrder,
436 BBlockTransferSrcAccessOrder,
437 BBlockTransferSrcVectorDim,
438 BBlockTransferSrcScalarPerVector,
439 BBlockTransferDstScalarPerVector_BK1,
440 false,
441 BBlockLdsExtraN,
442 CShuffleMXdlPerWavePerShuffle,
443 CShuffleNXdlPerWavePerShuffle,
444 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
445 CShuffleBlockTransferScalarPerVector_NPerBlock,
446 CReduceThreadClusterLengths_MPerBlock_NPerBlock,
447 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
448 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
449 LoopSched>;
452
453 // Argument
454 struct Argument : public BaseArgument
455 {
456 Argument(const ADataType* p_a_grid,
457 const BDataType* p_b_grid,
458 CDataType* p_c_grid,
459 const BiasDataType* p_bias_grid,
460 const D0DataType* p_d0_grid,
461 ReducePtrsGlobal p_reduces_grid,
462 index_t MRaw,
463 index_t NRaw,
464 index_t KRaw,
465 index_t StrideA,
466 index_t StrideB,
467 index_t StrideC,
468 index_t StrideC1,
469 AElementwiseOperation a_element_op,
470 BElementwiseOperation b_element_op,
471 CElementwiseOperation c_element_op,
472 D0ElementwiseOperation d0_element_op,
473 ReduceInElementwiseOperations reduce_in_element_ops,
474 ReduceAccElementwiseOperations reduce_out_element_ops)
475 : p_a_grid_{p_a_grid},
476 p_b_grid_{p_b_grid},
477 p_c_grid_{p_c_grid},
478 p_bias_grid_{p_bias_grid},
479 p_d0_grid_{p_d0_grid},
480 p_reduces_grid_{p_reduces_grid},
485 c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
487 block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
488 a_element_op_{a_element_op},
489 b_element_op_{b_element_op},
490 c_element_op_{c_element_op},
491 d0_element_op_{d0_element_op},
492 reduce_in_element_ops_{reduce_in_element_ops},
493 reduce_out_element_ops_{reduce_out_element_ops}
494 {
495 }
496
497 // private:
498 const ADataType* p_a_grid_;
499 const BDataType* p_b_grid_;
500 CDataType* p_c_grid_;
501 const BiasDataType* p_bias_grid_;
502 const D0DataType* p_d0_grid_;
503 ReducePtrsGlobal p_reduces_grid_;
511 AElementwiseOperation a_element_op_;
512 BElementwiseOperation b_element_op_;
513 CElementwiseOperation c_element_op_;
514 D0ElementwiseOperation d0_element_op_;
515 ReduceInElementwiseOperations reduce_in_element_ops_;
516 ReduceAccElementwiseOperations reduce_out_element_ops_;
517 };
518
519 // Invoker
520 struct Invoker : public BaseInvoker
521 {
523
// Launches the GEMM + bias/add + reduce kernel for one GridwiseGemm instantiation and
// returns the value produced by launch_and_time_kernel.
//
// arg           - problem description: data pointers, grid descriptors, tile map and
//                 element-wise operators (all read through arg.*_ members below).
// stream_config - forwarded unchanged to launch_and_time_kernel.
//
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects the argument.
// NOTE(review): this listing is missing several source lines (the remaining CheckValidity
// arguments, the `kernel` declarations, and some launch arguments); comments below describe
// only what is visible here.
524 template <typename GridwiseGemm>
525 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
526 {
// Refuse to launch with a configuration the gridwise GEMM reports as invalid.
527 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
531 {
532 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
533 }
// Re-express the M x N descriptors in the (MBlock, MPerBlock, NBlock, NPerBlock) form
// that is passed to the kernel launch below.
534 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
535 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
536 arg.c_grid_desc_m_n_);
537
538 auto c0_grid_desc_mblock_mperblock_nblock_nperblock =
539 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
541
542 auto c1_grid_desc_mblock_mperblock_nblock_nperblock =
543 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
545
546 auto reduce_grid_desc_mblock_mperblock =
547 GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_);
// The number of workgroups comes from the tile map applied to the C descriptor.
548 const index_t grid_size =
549 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
550
// K is recovered from the A descriptor as length(I0) * length(I2), i.e. AK0 * AK1
// going by the descriptor's name.
551 const auto K =
552 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
553
554 float elapsed_time = 0.0f;
// The two branches are identical except for the trailing bool template argument
// (true / false), selected by CalculateHasMainKBlockLoop(K).
555 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
556 {
558 GridwiseGemm,
559 ADataType, // TODO: distinguish A/B datatype
560 CDataType,
561 BiasDataType,
562 D0DataType,
563 ReducePtrsGlobal,
564 AElementwiseOperation,
565 BElementwiseOperation,
566 CElementwiseOperation,
567 D0ElementwiseOperation,
568 ReduceInElementwiseOperations,
569 ReduceAccElementwiseOperations,
572 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
573 typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
574 typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
575 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
576 typename GridwiseGemm::DefaultBlock2CTileMap,
577 true>;
578
// One-dimensional launch: grid_size workgroups of BlockSize threads; the literal 0 is
// the lds_byte parameter of launch_and_time_kernel.
579 elapsed_time =
580 launch_and_time_kernel(stream_config,
581 kernel,
582 dim3(grid_size),
583 dim3(BlockSize),
584 0,
585 arg.p_a_grid_,
586 arg.p_b_grid_,
587 arg.p_c_grid_,
588 arg.p_bias_grid_,
589 arg.p_d0_grid_,
590 arg.p_reduces_grid_,
591 arg.a_element_op_,
592 arg.b_element_op_,
593 arg.c_element_op_,
594 arg.d0_element_op_,
599 c_grid_desc_mblock_mperblock_nblock_nperblock,
600 c0_grid_desc_mblock_mperblock_nblock_nperblock,
601 c1_grid_desc_mblock_mperblock_nblock_nperblock,
602 reduce_grid_desc_mblock_mperblock,
604 }
605 else
606 {
608 GridwiseGemm,
609 ADataType, // TODO: distinguish A/B datatype
610 CDataType,
611 BiasDataType,
612 D0DataType,
613 ReducePtrsGlobal,
614 AElementwiseOperation,
615 BElementwiseOperation,
616 CElementwiseOperation,
617 D0ElementwiseOperation,
618 ReduceInElementwiseOperations,
619 ReduceAccElementwiseOperations,
622 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
623 typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
624 typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
625 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
626 typename GridwiseGemm::DefaultBlock2CTileMap,
627 false>;
628
// Same launch as the branch above, with the `false` kernel instantiation.
629 elapsed_time =
630 launch_and_time_kernel(stream_config,
631 kernel,
632 dim3(grid_size),
633 dim3(BlockSize),
634 0,
635 arg.p_a_grid_,
636 arg.p_b_grid_,
637 arg.p_c_grid_,
638 arg.p_bias_grid_,
639 arg.p_d0_grid_,
640 arg.p_reduces_grid_,
641 arg.a_element_op_,
642 arg.b_element_op_,
643 arg.c_element_op_,
644 arg.d0_element_op_,
649 c_grid_desc_mblock_mperblock_nblock_nperblock,
650 c0_grid_desc_mblock_mperblock_nblock_nperblock,
651 c1_grid_desc_mblock_mperblock_nblock_nperblock,
652 reduce_grid_desc_mblock_mperblock,
654 }
655
656 return elapsed_time;
657 }
658
660
661 // polymorphic
// Type-erased entry point: downcasts p_arg to this operation's Argument and forwards to
// the Run overload taking `const Argument&` (that overload is not visible in this listing;
// presumably it is supplied by the INVOKER_RUN_IMPL macro — TODO confirm).
//
// p_arg         - must point to an Argument of this device operation.
// stream_config - forwarded unchanged.
//
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing
// a null pointer or a BaseArgument of a different dynamic type is undefined behavior.
662 float Run(const BaseArgument* p_arg,
663 const StreamConfig& stream_config = StreamConfig{}) override
664 {
665 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
666 }
667 };
668
// Compile-time validity hook for this template instantiation.
// Currently a stub: it performs no checks and always returns true (see TODO below).
669 static constexpr bool IsValidCompilationParameter()
670 {
671 // TODO: properly implement this check
672 return true;
673 }
674
675 static bool IsSupportedArgument(const Argument& arg)
676 {
678 {
679 return false;
680 }
681 if(get_warp_size() == 64)
682 {
683 if constexpr(NXdlPerWave64 > 0)
684 {
689 }
690 }
691 else
692 {
693 if constexpr(NXdlPerWave32 > 0)
694 {
699 }
700 }
701 return false;
702 }
703
704 // polymorphic
// Type-erased overload: downcasts p_arg to Argument and delegates to the static
// IsSupportedArgument(const Argument&).
//
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a null
// pointer or a BaseArgument of a different dynamic type is undefined behavior rather
// than a `false` return.
705 bool IsSupportedArgument(const BaseArgument* p_arg) override
706 {
707 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
708 }
709
710 static constexpr int NumReduce = ReduceOperations::Size();
// Builds a typed Argument (returned by value) from type-erased inputs.
//
// p_a, p_b, p_bias, p_c - cast to ADataType / BDataType / BiasDataType / CDataType pointers.
// p_ds                  - only p_ds[0] is used, cast to `const D0DataType*`.
// p_reduces             - each entry is cast to the pointer type of the matching element
//                         of ReducePtrsGlobal.
// M, N, K, StrideA/B/C  - forwarded unchanged to the Argument constructor.
// StrideDs              - only StrideDs[0] is used; it is passed in the position of the
//                         constructor's StrideC1 parameter.
// gemm_element_ops      - [0], [1], [2] are cast to the A/B/C element-wise operator types
//                         and copied.
// d_element_ops         - [0] is cast to D0ElementwiseOperation and copied.
// reduce_in_element_op / reduce_out_element_op
//                       - each entry is cast to the matching tuple element type and copied.
//
// NOTE(review): every operator pointer is dereferenced without a null check; the casts are
// static_cast from void*, so the caller is responsible for passing the exact types.
// NOTE(review): the closing arguments of the generate_tuple calls are missing from this
// listing.
711 static auto MakeArgument(const void* p_a,
712 const void* p_b,
713 const void* p_bias,
714 std::array<const void*, 1> p_ds,
715 void* p_c,
716 std::array<void*, NumReduce> p_reduces,
717 ck::index_t M,
718 ck::index_t N,
719 ck::index_t K,
720 ck::index_t StrideA,
721 ck::index_t StrideB,
722 ck::index_t StrideC,
723 std::array<ck::index_t, 1> StrideDs,
724 std::array<void*, 3> gemm_element_ops,
725 std::array<void*, 1> d_element_ops,
726 std::array<void*, NumReduce> reduce_in_element_op,
727 std::array<void*, NumReduce> reduce_out_element_op)
728 {
// Recover the typed reduce output pointers: a default-constructed ReducePtrsGlobal is
// used only to obtain the static type of element I.
729 ReducePtrsGlobal reduce_tuple = generate_tuple(
730 [&](auto I) {
731 auto tmp = ReducePtrsGlobal{}[I];
732 using T = remove_pointer_t<decltype(tmp)>;
733 return static_cast<T*>(p_reduces[I]);
734 },
736
// Same type-recovery trick for the reduce operators, but the pointee is copied into
// the tuple instead of keeping the pointer.
737 ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
738 [&](auto I) {
739 auto tmp = ReduceInElementwiseOperations{}[I];
740 using T = remove_pointer_t<decltype(tmp)>;
741 return *(static_cast<T*>(reduce_in_element_op[I]));
742 },
744 ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
745 [&](auto I) {
746 auto tmp = ReduceAccElementwiseOperations{}[I];
747 using T = remove_pointer_t<decltype(tmp)>;
748 return *(static_cast<T*>(reduce_out_element_op[I]));
749 },
751
// Copy the GEMM and D0 element-wise operators out of their type-erased slots.
752 AElementwiseOperation a_element_op =
753 *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
754 BElementwiseOperation b_element_op =
755 *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
756 CElementwiseOperation c_element_op =
757 *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
758 D0ElementwiseOperation d_element_op =
759 *(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
760
761 return Argument{static_cast<const ADataType*>(p_a),
762 static_cast<const BDataType*>(p_b),
763 static_cast<CDataType*>(p_c),
764 static_cast<const BiasDataType*>(p_bias),
765 static_cast<const D0DataType*>(p_ds[0]),
766 reduce_tuple,
767 M,
768 N,
769 K,
770 StrideA,
771 StrideB,
772 StrideC,
773 StrideDs[0],
774 a_element_op,
775 b_element_op,
776 c_element_op,
777 d_element_op,
778 reduce_in_element_ops,
779 reduce_out_element_ops};
780 }
781
// Returns a default-constructed Invoker by value.
782 static auto MakeInvoker() { return Invoker{}; }
783
784 // polymorphic
// Heap-allocating counterpart of MakeArgument: performs the same type-erased-to-typed
// conversions and returns the Argument as std::unique_ptr<BaseArgument>.
//
// Parameters are interpreted exactly as in the conversion code below: only p_ds[0],
// StrideDs[0], gemm_element_ops[0..2] and d_element_ops[0] are read, and StrideDs[0] is
// passed in the position of the Argument constructor's StrideC1 parameter.
// The trailing index_t parameter (KBatch) is unnamed and unused here.
//
// NOTE(review): every operator pointer is dereferenced without a null check; the casts are
// static_cast from void*, so the caller is responsible for passing the exact types.
// NOTE(review): the closing arguments of the generate_tuple calls are missing from this
// listing.
785 std::unique_ptr<BaseArgument>
786 MakeArgumentPointer(const void* p_a,
787 const void* p_b,
788 const void* p_bias,
789 std::array<const void*, 1> p_ds,
790 void* p_c,
791 std::array<void*, NumReduce> p_reduces,
792 ck::index_t M,
793 ck::index_t N,
794 ck::index_t K,
795 ck::index_t StrideA,
796 ck::index_t StrideB,
797 ck::index_t StrideC,
798 std::array<ck::index_t, 1> StrideDs,
799 std::array<void*, 3> gemm_element_ops,
800 std::array<void*, 1> d_element_ops,
801 std::array<void*, NumReduce> reduce_in_element_op,
802 std::array<void*, NumReduce> reduce_out_element_op,
803 index_t /* KBatch */ = 1) override
804 {
// Recover the typed reduce output pointers: a default-constructed ReducePtrsGlobal is
// used only to obtain the static type of element I.
805 ReducePtrsGlobal reduce_tuple = generate_tuple(
806 [&](auto I) {
807 auto tmp = ReducePtrsGlobal{}[I];
808 using T = remove_pointer_t<decltype(tmp)>;
809 return static_cast<T*>(p_reduces[I]);
810 },
812
// Reduce operators are copied by value out of their type-erased slots.
813 ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
814 [&](auto I) {
815 auto tmp = ReduceInElementwiseOperations{}[I];
816 using T = remove_pointer_t<decltype(tmp)>;
817 return *(static_cast<T*>(reduce_in_element_op[I]));
818 },
820 ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
821 [&](auto I) {
822 auto tmp = ReduceAccElementwiseOperations{}[I];
823 using T = remove_pointer_t<decltype(tmp)>;
824 return *(static_cast<T*>(reduce_out_element_op[I]));
825 },
827
// Copy the GEMM and D0 element-wise operators out of their type-erased slots.
828 AElementwiseOperation a_element_op =
829 *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
830 BElementwiseOperation b_element_op =
831 *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
832 CElementwiseOperation c_element_op =
833 *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
834 D0ElementwiseOperation d_element_op =
835 *(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
836
837 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
838 static_cast<const BDataType*>(p_b),
839 static_cast<CDataType*>(p_c),
840 static_cast<const BiasDataType*>(p_bias),
841 static_cast<const D0DataType*>(p_ds[0]),
842 reduce_tuple,
843 M,
844 N,
845 K,
846 StrideA,
847 StrideB,
848 StrideC,
849 StrideDs[0],
850 a_element_op,
851 b_element_op,
852 c_element_op,
853 d_element_op,
854 reduce_in_element_ops,
855 reduce_out_element_ops);
856 }
857
858 // polymorphic
// Returns a heap-allocated Invoker as std::unique_ptr<BaseInvoker>.
// NOTE(review): a temporary Invoker{} is built and then passed to make_unique, which
// constructs the heap object from it; `std::make_unique<Invoker>()` would construct the
// object directly.
859 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
860 {
861 return std::make_unique<Invoker>(Invoker{});
862 }
863
864 // polymorphic
// Returns a human-readable identifier for this instantiation, of the form
//   DeviceGemmBiasAddReduce_Xdl_CShuffle<v0, v1, ..., v13>
// where the values are, in order: BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1,
// MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferSrcScalarPerVector,
// BBlockTransferSrcScalarPerVector, CShuffleMXdlPerWavePerShuffle and
// CShuffleNXdlPerWavePerShuffle, separated by ", ".
// Only these fourteen tuning parameters are included; data types, layouts and the other
// template parameters are not part of the string.
865 std::string GetTypeString() const override
866 {
867 auto str = std::stringstream();
868
869 // clang-format off
870 str << "DeviceGemmBiasAddReduce_Xdl_CShuffle"
871 << "<"
872 << BlockSize << ", "
873 << MPerBlock << ", "
874 << NPerBlock << ", "
875 << KPerBlock << ", "
876 << AK1 << ", "
877 << BK1 << ", "
878 << MPerXDL << ", "
879 << NPerXDL << ", "
880 << MXdlPerWave << ", "
881 << NXdlPerWave << ", "
882 << ABlockTransferSrcScalarPerVector << ", "
883 << BBlockTransferSrcScalarPerVector << ", "
884 << CShuffleMXdlPerWavePerShuffle << ", "
885 << CShuffleNXdlPerWavePerShuffle
886 << ">";
887 // clang-format on
888
889 return str.str();
890 }
891};
892
893} // namespace device
894} // namespace tensor_operation
895} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_bias_grid, const FloatC1 *__restrict__ p_d0_grid, ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const C1ElementwiseOperation c1_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c0_grid_desc_mblock_mperblock_nblock_nperblock, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:45
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:180
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:384
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:279
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:455
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:506
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:505
CElementwiseOperation c_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:513
CDataType * p_c_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:500
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BiasDataType *p_bias_grid, const D0DataType *p_d0_grid, ReducePtrsGlobal p_reduces_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideC1, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, D0ElementwiseOperation d0_element_op, ReduceInElementwiseOperations reduce_in_element_ops, ReduceAccElementwiseOperations reduce_out_element_ops)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:456
const BiasDataType * p_bias_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:501
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:510
AElementwiseOperation a_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:511
ReducePtrsGlobal p_reduces_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:503
const D0DataType * p_d0_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:502
ReduceAccElementwiseOperations reduce_out_element_ops_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:516
D0ElementwiseOperation d0_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:514
BElementwiseOperation b_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:512
C0GridDesc_M_N c0_grid_desc_m_n_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:507
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:504
const ADataType * p_a_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:498
C1GridDesc_M_N c1_grid_desc_m_n_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:508
ReduceInElementwiseOperations reduce_in_element_ops_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:515
ReduceGridDesc_M reduce_grid_desc_m_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:509
const BDataType * p_b_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:499
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:521
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:525
DeviceOp::Argument Argument
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:522
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:662
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:81
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:195
GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:391
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:387
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:669
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:451
static constexpr int NumReduce
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:710
static constexpr auto I2
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:90
static auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:357
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:859
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) C1GridDesc_M_N
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:386
static constexpr auto I0
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:88
static auto MakeInvoker()
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:782
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:705
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:298
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:675
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:382
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:92
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:450
std::string GetTypeString() const override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:865
decltype(MakeCGridDescriptor_M_N(1, 1, 0)) C0GridDesc_M_N
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:385
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:85
static constexpr auto NXdlPerWave32
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:86
DeviceGemmBiasAddReduce_Xdl_CShuffle DeviceOp
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:82
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 1 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 1 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 1 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, index_t=1) override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:786
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:383
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:384
static constexpr auto I1
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:89
static auto MakeArgument(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 1 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 1 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 1 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:711
Definition device_gemm_reduce.hpp:17