device_gemm_wmma_cshuffle_v3r1.hpp Source File

device_gemm_wmma_cshuffle_v3r1.hpp Source File

Composable Kernel: device_gemm_wmma_cshuffle_v3r1.hpp Source File
device_gemm_wmma_cshuffle_v3r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <sstream>
7#include <type_traits>
8#include <typeinfo>
9#include <memory>
10#include <array>
11#include <stdexcept>
12
14#include "ck/ck.hpp"
24
28
29namespace ck {
30namespace tensor_operation {
31namespace device {
32
33template <typename ALayout,
34 typename BLayout,
35 typename DsLayout,
36 typename CLayout,
37 typename ADataType,
38 typename BDataType,
39 typename DsDataType,
40 typename CDataType,
41 typename GemmAccDataType,
42 typename CShuffleDataType,
43 typename AElementwiseOperation,
44 typename BElementwiseOperation,
45 typename CElementwiseOperation,
46 GemmSpecialization GemmSpec,
47 index_t BlockSize,
48 index_t MPerBlock,
49 index_t NPerBlock,
50 index_t KPerBlock,
51 index_t AK1,
52 index_t BK1,
53 index_t MPerWmma,
54 index_t NPerWmma,
55 index_t MRepeat,
56 index_t NRepeat,
57 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
58 typename ABlockTransferThreadClusterArrangeOrder,
59 typename ABlockTransferSrcAccessOrder,
60 index_t ABlockTransferSrcVectorDim,
61 index_t ABlockTransferSrcScalarPerVector,
62 index_t ABlockTransferDstScalarPerVector_AK1,
63 bool ABlockLdsExtraM,
64 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
65 typename BBlockTransferThreadClusterArrangeOrder,
66 typename BBlockTransferSrcAccessOrder,
67 index_t BBlockTransferSrcVectorDim,
68 index_t BBlockTransferSrcScalarPerVector,
69 index_t BBlockTransferDstScalarPerVector_BK1,
70 bool BBlockLdsExtraN,
71 index_t CShuffleMRepeatPerShuffle,
72 index_t CShuffleNRepeatPerShuffle,
73 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
74 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
77 typename ReduceDataType = CDataType,
78 typename ComputeTypeA = CDataType,
79 typename ComputeTypeB = ComputeTypeA>
81 BLayout,
82 DsLayout,
83 CLayout,
84 ADataType,
85 BDataType,
86 DsDataType,
87 CDataType,
88 AElementwiseOperation,
89 BElementwiseOperation,
90 CElementwiseOperation>
91{
// Number of auxiliary D tensors, i.e. the size of the DsDataType tuple.
92 static constexpr index_t NumDTensor = DsDataType::Size();
93
95
// NOTE(review): the head of the following declaration is missing from this listing, and
// so are some of its arguments (the embedded numbering skips 96, 101-102, 109 and 140).
// What remains looks like the template-argument list of the GridwiseGemm alias that
// Argument and Invoker use further down -- confirm against the real header.
97 ALayout,
98 BLayout,
// An empty tuple is passed here rather than the class's own DsLayout.
// TODO confirm this is the gridwise kernel's Ds-layout slot.
99 Tuple<>,
100 CLayout,
103 GemmAccDataType,
// NOTE(review): ReduceDataType appears twice with an empty tuple in between; the
// class-level CShuffleDataType parameter is not used anywhere in this visible list.
// Presumably these are the shuffle/Ds/output type slots -- confirm.
104 ReduceDataType,
105 Tuple<>,
106 ReduceDataType,
107 AElementwiseOperation,
108 BElementwiseOperation,
110 GemmSpec,
111 BlockSize,
112 MPerBlock,
113 NPerBlock,
114 KPerBlock,
115 AK1,
116 BK1,
117 MPerWmma,
118 NPerWmma,
119 MRepeat,
120 NRepeat,
121 ABlockTransferThreadClusterLengths_AK0_M_AK1,
122 ABlockTransferThreadClusterArrangeOrder,
123 ABlockTransferSrcAccessOrder,
124 ABlockTransferSrcVectorDim,
125 ABlockTransferSrcScalarPerVector,
126 ABlockTransferDstScalarPerVector_AK1,
// Hard-coded flag; its meaning is defined by the gridwise template, not visible here.
127 false,
128 ABlockLdsExtraM,
129 BBlockTransferThreadClusterLengths_BK0_N_BK1,
130 BBlockTransferThreadClusterArrangeOrder,
131 BBlockTransferSrcAccessOrder,
132 BBlockTransferSrcVectorDim,
133 BBlockTransferSrcScalarPerVector,
134 BBlockTransferDstScalarPerVector_BK1,
// Hard-coded flag; its meaning is defined by the gridwise template, not visible here.
135 false,
136 BBlockLdsExtraN,
137 CShuffleMRepeatPerShuffle,
138 CShuffleNRepeatPerShuffle,
139 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
141 BlkGemmPipeSched,
142 BlkGemmPipelineVer,
143 ComputeTypeA,
144 ComputeTypeB,
// Two trailing hard-coded flags; meaning not visible in this file -- TODO confirm.
145 false,
146 false>;
147
148 struct Argument : public GridwiseGemm::Argument
149 {
150 Argument(std::array<const void*, 1> p_a_grid_,
151 std::array<const void*, 1> p_b_grid_,
152 const ::std::array<const void*, NumDTensor> p_ds_,
153 CDataType* p_c_grid_,
154 index_t M_,
155 index_t N_,
156 index_t K_,
157 std::array<index_t, 1> StrideA_,
158 std::array<index_t, 1> StrideB_,
159 const ::std::array<index_t, NumDTensor> stride_ds_,
160 index_t StrideC_,
161 index_t KBatch_,
162 AElementwiseOperation a_element_op_,
163 BElementwiseOperation b_element_op_,
164 CElementwiseOperation c_element_op_)
165 : GridwiseGemm::Argument(p_a_grid_,
166 p_b_grid_,
167 ::std::array<const void*, 0>{},
168 reinterpret_cast<ReduceDataType*>(p_c_grid_),
169 M_,
170 N_,
171 K_,
172 StrideA_,
173 StrideB_,
174 std::array<index_t, 0>{},
175 StrideC_,
176 KBatch_,
177 a_element_op_,
178 b_element_op_,
179 PassThrough{},
180 true),
181 p_c_grid(p_c_grid_),
182 c_element_op(c_element_op_),
183 p_ds(p_ds_),
184 StrideDs(stride_ds_)
185 {
186 }
187
188 CDataType* p_c_grid;
189 CElementwiseOperation c_element_op;
190 const ::std::array<const void*, NumDTensor> p_ds;
191 ::std::array<index_t, NumDTensor> StrideDs;
192 };
193
// Output element-wise operation of the reduce stage: an alias of CElementwiseOperation.
195 using OutElementwiseOperation = CElementwiseOperation;
196
// NOTE(review): fragment. The declaration this lambda belongs to is not visible here
// (presumably DsVectorLengthSequence, which the decltype below refers to), and the
// `if` branch that precedes the `else` is missing as well.
198 [](auto i) {
199 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
202 else
203 return Number<1>{};
204 },
206
// NOTE(review): fragment of a template instantiation whose head and some arguments are
// missing (the embedded numbering skips 207 and 215-216); presumably the
// DeviceReduceInstance alias used by RunReduce. The trailing comments name each slot.
208 ReduceDataType, // InDataType
209 DsDataType, // DsDatatype
210 GemmAccDataType, // AccDataType
211 CDataType, // OutDataType
212 3, // Rank
213 1, // NumReduceDim
214 ReduceAdd,
// The block size is hard-coded to 256, independent of the GEMM's BlockSize parameter.
217 256, // BlockSize_
218 CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_
219 1, // KThreadSliceSize_
220 0, // InSrcVectorDim_
221 CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_
222 CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
223 decltype(DsVectorLengthSequence)>;
224
225 struct Invoker : public BaseInvoker
226 {
// Runs the second-stage reduction with DeviceReduceInstance.
//
// The input is described as a 3-D tensor of lengths (KBatch, M, N) read from
// arg.p_workspace_; dimension 0 is reduced and the 2-D (M, N) result is written to
// arg.p_c_grid, with the D tensors (arg.p_ds) passed alongside.
// Returns the time reported by the reduce invoker.
// Throws std::runtime_error when the reduce instance rejects the argument.
227 float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
228 {
229 static constexpr index_t NumInDim = 3;
230 static constexpr index_t NumOutDim = 2;
231
232 ::std::array<index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
233 ::std::array<index_t, NumOutDim> out_lengths = {arg.M, arg.N};
234
235 ::std::array<index_t, NumInDim> in_strides;
236 ::std::array<index_t, NumOutDim> out_strides;
// NOTE(review): the condition selecting between the two branches below is missing from
// this listing (presumably a compile-time check on CLayout -- confirm).
// First branch: N is the unit-stride dimension. Second branch: M is the unit-stride one.
// NOTE(review): both branches derive the output strides from M/N only; StrideC is not
// used here, so the output is treated as densely packed -- confirm this is intended.
238 {
239 in_strides = {arg.M * arg.N, arg.N, 1};
240 out_strides = {arg.N, 1};
241 }
242 else
243 {
244 in_strides = {arg.M * arg.N, 1, arg.M};
245 out_strides = {1, arg.M};
246 }
247
// Reduce over dimension 0 of in_lengths, i.e. the KBatch dimension.
248 ::std::array<int, 1> reduce_dims{0};
249
250 ::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
251 ::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
252
// Every D tensor gets the output's lengths; its strides come from arg.StrideDs[i].
253 static_for<0, NumDTensor, 1>{}([&](auto i) {
254 DsLengths[i] = out_lengths;
255
256 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// NOTE(review): the condition on DLayout that picks a branch is missing from this listing.
258 {
259 DsStrides[i] = {arg.StrideDs[i], 1};
260 }
261 else
262 {
263 DsStrides[i] = {1, arg.StrideDs[i]};
264 }
265 });
266
267 auto reduce = DeviceReduceInstance{};
268
// The workspace is the reduce input and arg.p_c_grid is the reduce output.
269 auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
270 in_strides,
271 DsLengths,
272 DsStrides,
273 out_lengths,
274 out_strides,
275 reduce_dims,
276 arg.p_workspace_,
277 arg.p_ds,
278 arg.p_c_grid,
279 PassThrough{},
// NOTE(review): the final argument and the closing of this call are missing from the
// listing (presumably the output element-wise op -- confirm).
281
282 auto invoker_ptr = reduce.MakeInvokerPointer();
283
284 float ave_time = 0;
285
286 if(reduce.IsSupportedArgument(argument_ptr.get()))
287 {
288 ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
289 }
290 else
291 {
292 throw ::std::runtime_error(
293 "The runtime parameters are not supported by the device instance.");
294 }
295
296 return ave_time;
297 }
298
// Launches the GEMM kernel and, when a workspace is needed, the follow-up reduction.
//
// Returns the GEMM time plus, if it ran, the time returned by RunReduce.
// Throws std::runtime_error when a workspace is needed but arg.p_workspace_ is null,
// and from the validity check below.
299 float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
300 {
// Work on a copy of the GridwiseGemm::Argument base part so that p_e_grid can be
// redirected below without modifying the caller's arg_.
301 auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
302
303 // workspace required when doing two-kernel reduce or Ds present
// NOTE(review): the second operand of the && below is missing from this listing.
304 const bool need_workspace = !(!(arg.IsReduceAdd() || NumDTensor > 0) &&
306 if(need_workspace)
307 {
308 if(arg.p_workspace_ == nullptr)
309 {
310 throw ::std::runtime_error("using reduce, but empty workspace!");
311 }
// The GEMM kernel writes into the workspace instead of the final C buffer.
312 arg.p_e_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
313 }
314
315 if(stream_config.log_level_ > 0)
316 {
317 arg.Print();
318 }
319
// NOTE(review): the condition guarding this throw is missing from the listing
// (presumably a GridwiseGemm validity check -- confirm).
321 {
322 throw ::std::runtime_error("wrong! GridwiseGemm has invalid setting");
323 }
324
325 index_t gdx, gdy, gdz;
326 ::std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
327
328 float ave_time = 0;
329
// K_split = ceil(K / (KBatch * KPerBlock)) * KPerBlock; it is the value handed to
// CalculateHasMainKBlockLoop below.
330 index_t k_grain = arg.KBatch * KPerBlock;
331 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
332
333 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
334
335 constexpr index_t minimum_occupancy =
336 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
337
// The two branches differ only in the boolean kernel template argument (true / false),
// chosen by has_main_k_block_loop.
// NOTE(review): the kernel name and its other template arguments are missing from the
// listing in both branches.
338 if(has_main_k_block_loop)
339 {
340 const auto kernel =
342 true,
344 minimum_occupancy>;
345 ave_time = launch_and_time_kernel(
346 stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
347 }
348 else
349 {
350 const auto kernel =
352 false,
354 minimum_occupancy>;
355 ave_time = launch_and_time_kernel(
356 stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
357 }
358
// The reduction receives the original arg_, not the modified base copy.
359 if(need_workspace)
360 {
361 ave_time += RunReduce(arg_, stream_config);
362 }
363
364 return ave_time;
365 }
366
367 // polymorphic
368 float Run(const BaseArgument* p_arg,
369 const StreamConfig& stream_config = StreamConfig{}) override
370 {
371 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
372 }
373 };
374
/// Reports whether this template instantiation is a valid compile-time configuration.
/// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this
    constexpr bool accept_every_configuration = true;
    return accept_every_configuration;
}
380
// Checks whether this instance can run the given argument.
// Returns false early from the two guards below; otherwise the result comes from the
// final expression, whose head is missing from this listing.
381 static bool IsSupportedArgument(const Argument& arg)
382 {
// NOTE(review): the condition of this first guard is missing from the listing
// (presumably a device-capability check -- confirm).
384 {
385 return false;
386 }
387
// K must be a multiple of both AK1 and BK1 unless GemmSpec is one of the
// specializations listed in this condition.
388 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
389 GemmSpec == GemmSpecialization::NKPadding ||
390 GemmSpec == GemmSpecialization::MNKPadding ||
391 GemmSpec == GemmSpecialization::KPadding))
392 {
393 return false;
394 }
395
// NOTE(review): the start of this statement is missing; what is visible passes the
// GridwiseGemm::Argument base part of arg to some call (presumably the gridwise
// validity check -- confirm).
397 *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
398 }
399
400 bool IsSupportedArgument(const BaseArgument* p_arg) override
401 {
402 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
403 }
404
405 static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
406 {
407 return GridwiseGemm::CalculateGridSize(M, N, KBatch);
408 }
409
410 static constexpr index_t GetBlockSize() { return BlockSize; }
411
416
417 static auto MakeArgument(const ADataType* p_a,
418 const BDataType* p_b,
419 const ::std::array<const void*, NumDTensor> p_ds,
420 CDataType* p_c,
421 index_t M,
422 index_t N,
423 index_t K,
424 index_t StrideA,
425 index_t StrideB,
426 const ::std::array<index_t, NumDTensor> stride_ds,
427 index_t StrideC,
428 index_t KBatch,
429 AElementwiseOperation a_element_op,
430 BElementwiseOperation b_element_op,
431 CElementwiseOperation c_element_op)
432 {
433 return Argument{std::array<const void*, 1>{p_a},
434 std::array<const void*, 1>{p_b},
435 p_ds,
436 p_c,
437 M,
438 N,
439 K,
440 std::array<index_t, 1>{StrideA},
441 std::array<index_t, 1>{StrideB},
442 stride_ds,
443 StrideC,
444 KBatch,
445 a_element_op,
446 b_element_op,
447 c_element_op};
448 }
449
450 static auto MakeInvoker() { return Invoker{}; }
451
452 // polymorphic
453 ::std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
454 {
455 return ::std::make_unique<Invoker>(Invoker{});
456 }
457
458 // Polymorphic interfaces
459 ::std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
460 const void* p_b,
461 ::std::array<const void*, NumDTensor> p_ds,
462 void* p_c,
463 index_t M,
464 index_t N,
465 index_t K,
466 index_t StrideA,
467 index_t StrideB,
468 ::std::array<index_t, NumDTensor> DsStrides,
469 index_t StrideC,
470 index_t KSplit,
471 AElementwiseOperation a_element_op,
472 BElementwiseOperation b_element_op,
473 CElementwiseOperation c_element_op) override
474 {
475 return ::std::make_unique<Argument>(std::array<const void*, 1>{p_a},
476 std::array<const void*, 1>{p_b},
477 p_ds,
478 static_cast<CDataType*>(p_c),
479 M,
480 N,
481 K,
482 std::array<index_t, 1>{StrideA},
483 std::array<index_t, 1>{StrideB},
484 DsStrides,
485 StrideC,
486 KSplit,
487 a_element_op,
488 b_element_op,
489 c_element_op);
490 }
491
492 ::std::string GetTypeString() const override
493 {
494 auto str = ::std::stringstream();
495
496 auto BlkGemmPipelineSchedulerToString = [](BlockGemmPipelineScheduler s) {
497 switch(s)
498 {
499 case BlockGemmPipelineScheduler::Intrawave: return ::std::string("Intrawave");
500 case BlockGemmPipelineScheduler::Interwave: return ::std::string("Interwave");
501 }
502 return ::std::string("?");
503 };
504
505 auto BlkGemmPipelineVersionToString = [](BlockGemmPipelineVersion v) {
506 switch(v)
507 {
508 case BlockGemmPipelineVersion::v1: return ::std::string("v1");
509 case BlockGemmPipelineVersion::v2: return ::std::string("v2");
510 case BlockGemmPipelineVersion::v3: return ::std::string("v3");
511 case BlockGemmPipelineVersion::v4: return ::std::string("v4");
512 case BlockGemmPipelineVersion::v5: return ::std::string("v5");
513 }
514 return ::std::string("v?");
515 };
516
517 // clang-format off
518 str << "DeviceGemmWmmaUniversalReduce"
519 << "<"
520 << getGemmSpecializationString(GemmSpec) << ", "
521 << ::std::string(ALayout::name)[0]
522 << ::std::string(BLayout::name)[0]
523 << ::std::string(CLayout::name)[0]
524 << ">"
525 << " BlkSize: "
526 << BlockSize << ", "
527 << "BlkTile: "
528 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
529 << "WmmaTile: "
530 << MPerWmma<<"x"<<NPerWmma << ", "
531 << "WmmaRepeat: "
532 << MRepeat<<"x" << NRepeat<<", "
533 << "VmemReadVec: "
534 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
535 << "BlkGemmPipelineScheduler: "
536 << BlkGemmPipelineSchedulerToString(BlkGemmPipeSched) << ", "
537 << "BlkGemmPipelineVersion: "
538 << BlkGemmPipelineVersionToString(BlkGemmPipelineVer) << ", "
539 << "BlkGemmPipelinePrefetchStages: "
540 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
541 // clang-format on
542
543 return str.str();
544 }
545
546 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
547 {
548 auto arg = *dynamic_cast<const Argument*>(p_arg);
549
550 // Need workspace if using split-K or have D tensors
551 if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && is_same<CDataType, ReduceDataType>::value))
552 {
553 return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
554 }
555
556 return 0;
557 }
558};
559
560} // namespace device
561} // namespace tensor_operation
562} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
bool is_wmma_supported()
Definition host_utility/device_prop.hpp:127
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
STL namespace.
Definition ck/stream_config.hpp:10
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:852
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition multi_index_transform.hpp:13
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition reduction_operator.hpp:37
Definition device_base.hpp:197
Definition device_gemm_wmma_cshuffle_v3r1.hpp:149
Argument(std::array< const void *, 1 > p_a_grid_, std::array< const void *, 1 > p_b_grid_, const ::std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, 1 > StrideA_, std::array< index_t, 1 > StrideB_, const ::std::array< index_t, NumDTensor > stride_ds_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:150
CDataType * p_c_grid
Definition device_gemm_wmma_cshuffle_v3r1.hpp:188
const ::std::array< const void *, NumDTensor > p_ds
Definition device_gemm_wmma_cshuffle_v3r1.hpp:190
CElementwiseOperation c_element_op
Definition device_gemm_wmma_cshuffle_v3r1.hpp:189
::std::array< index_t, NumDTensor > StrideDs
Definition device_gemm_wmma_cshuffle_v3r1.hpp:191
Definition device_gemm_wmma_cshuffle_v3r1.hpp:226
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_wmma_cshuffle_v3r1.hpp:227
float Run(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_wmma_cshuffle_v3r1.hpp:299
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:368
Definition device_gemm_wmma_cshuffle_v3r1.hpp:91
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_gemm_wmma_cshuffle_v3r1.hpp:94
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:375
static size_t GetSharedMemoryNumberOfByte()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:412
static constexpr index_t NumDTensor
Definition device_gemm_wmma_cshuffle_v3r1.hpp:92
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, GemmAccDataType, ReduceDataType, Tuple<>, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, false, false > GridwiseGemm
Definition device_gemm_wmma_cshuffle_v3r1.hpp:96
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ::std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, const ::std::array< index_t, NumDTensor > stride_ds, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:417
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:405
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:381
static constexpr auto DsVectorLengthSequence
Definition device_gemm_wmma_cshuffle_v3r1.hpp:197
static auto MakeInvoker()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:450
::std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, ::std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, ::std::array< index_t, NumDTensor > DsStrides, index_t StrideC, index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:459
ck::reduce::Add ReduceAdd
Definition device_gemm_wmma_cshuffle_v3r1.hpp:194
CElementwiseOperation OutElementwiseOperation
Definition device_gemm_wmma_cshuffle_v3r1.hpp:195
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition device_gemm_wmma_cshuffle_v3r1.hpp:207
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:400
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:546
::std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:453
static constexpr index_t GetBlockSize()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:410
::std::string GetTypeString() const override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:492
Definition device_gemm_v2.hpp:57
Definition device_reduce_threadwise_multi_d.hpp:47
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op) override
Definition device_reduce_threadwise_multi_d.hpp:363
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340