device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp Source File

device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp Source File#

Composable Kernel: device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp Source File
device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <vector>
9
11#include "ck/utility/env.hpp"
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25// out[N, Ho, Wo, K] =
26// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
27template <
28 typename InDataType,
29 typename WeiDataType,
30 typename OutDataType,
31 typename AccDataType,
32 typename InElementwiseOperation,
33 typename WeiElementwiseOperation,
34 typename OutElementwiseOperation,
35 InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
36 ConvolutionForwardSpecialization ConvForwardSpecialization,
37 ck::index_t BlockSize,
38 ck::index_t MPerBlock,
39 ck::index_t NPerBlock,
40 ck::index_t K0PerBlock,
41 ck::index_t K1,
42 ck::index_t MPerXDL,
43 ck::index_t NPerXDL,
44 ck::index_t MXdlPerWave,
45 ck::index_t NXdlPerWave,
46 typename ABlockTransferThreadClusterLengths_K0_M_K1,
47 typename ABlockTransferThreadClusterArrangeOrder,
48 typename ABlockTransferSrcAccessOrder,
49 ck::index_t ABlockTransferSrcVectorDim,
50 ck::index_t ABlockTransferSrcScalarPerVector,
51 ck::index_t ABlockTransferDstScalarPerVector_K1,
52 bool ABlockLdsAddExtraM,
53 typename BBlockTransferThreadClusterLengths_K0_N_K1,
54 typename BBlockTransferThreadClusterArrangeOrder,
55 typename BBlockTransferSrcAccessOrder,
56 ck::index_t BBlockTransferSrcVectorDim,
57 ck::index_t BBlockTransferSrcScalarPerVector,
58 ck::index_t BBlockTransferDstScalarPerVector_K1,
59 bool BBlockLdsAddExtraN,
60 index_t CShuffleMXdlPerWavePerShuffle,
61 index_t CShuffleNXdlPerWavePerShuffle,
62 typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
63 index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
65 : public DeviceConvFwdBiasActivation<InElementwiseOperation,
66 WeiElementwiseOperation,
67 OutElementwiseOperation>
68{
69 using DeviceOp =
71
73 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
74 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
75
76 using ADataType = InDataType;
77 using BDataType = WeiDataType;
78 using CDataType = OutDataType;
79
80 // TODO make A/B datatype different
81 using ABDataType = InDataType;
82
83 // TODO make it support any # of spatial dimensions
84 static constexpr index_t NDimSpatial = 2;
85
86 static constexpr auto I0 = Number<0>{};
87 static constexpr auto I1 = Number<1>{};
88 static constexpr auto I2 = Number<2>{};
89 static constexpr auto I3 = Number<3>{};
90
91 static constexpr auto K1Number = Number<K1>{};
92 static constexpr auto GemmK1Number = K1Number;
93
94 static auto
98 std::vector<ck::index_t> input_spatial_lengths,
99 std::vector<ck::index_t> filter_spatial_lengths,
100 std::vector<ck::index_t> output_spatial_lengths,
101 std::vector<ck::index_t> conv_filter_strides,
102 std::vector<ck::index_t> conv_filter_dilations,
103 std::vector<ck::index_t> input_left_pads,
104 std::vector<ck::index_t> input_right_pads)
105 {
106 using namespace ck;
107
108 const index_t Hi = input_spatial_lengths[0];
109 const index_t Wi = input_spatial_lengths[1];
110
111 const index_t Ho = output_spatial_lengths[0];
112 const index_t Wo = output_spatial_lengths[1];
113
114 const index_t Y = filter_spatial_lengths[0];
115 const index_t X = filter_spatial_lengths[1];
116
117 const index_t ConvStrideH = conv_filter_strides[0];
118 const index_t ConvStrideW = conv_filter_strides[1];
119
120 const index_t ConvDilationH = conv_filter_dilations[0];
121 const index_t ConvDilationW = conv_filter_dilations[1];
122
123 const index_t InLeftPadH = input_left_pads[0];
124 const index_t InLeftPadW = input_left_pads[1];
125
126 const index_t InRightPadH = input_right_pads[0];
127 const index_t InRightPadW = input_right_pads[1];
128
129 const index_t GemmMRaw = N * Ho * Wo;
130 const index_t GemmN = K;
131
132 const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
133 const auto GemmMPad = GemmM - GemmMRaw;
134
135 if constexpr(ConvForwardSpecialization ==
137 { // 1x1, stride=1, pad=0
138 const index_t GemmK = Y * X * C;
139 assert(GemmK % GemmK1Number == 0);
140
141 const index_t GemmK0 = GemmK / GemmK1Number;
142
143 // A: input tensor
144 const auto in_gemmmraw_gemmk_grid_desc =
146
147 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
148 in_gemmmraw_gemmk_grid_desc,
150 make_right_pad_transform(GemmMRaw, GemmMPad)),
153
154 // B: weight tensor
155 const auto wei_gemmn_gemmk_grid_desc =
157
158 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
159 wei_gemmn_gemmk_grid_desc,
164
165 // C: output tensor
166 const auto out_gemmmraw_gemmn_grid_desc =
168
169 const auto out_gemmm_gemmn_grid_desc =
170 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
171 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
175
176 // C0: bias tensor: assume a contiguous vector
177 const auto bias_grid_desc_gemmm_gemmn =
179
180 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
181 wei_gemmk0_gemmn_gemmk1_grid_desc,
182 out_gemmm_gemmn_grid_desc,
183 bias_grid_desc_gemmm_gemmn);
184 }
185 else if constexpr(ConvForwardSpecialization ==
187 { // 1x1, pad=0
188 const index_t GemmK = Y * X * C;
189 assert(GemmK % GemmK1Number == 0);
190
191 const index_t GemmK0 = GemmK / GemmK1Number;
192
193 // A: input tensor
194 const auto in_n_hi_wi_c_grid_desc =
196
197 const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
198 in_n_hi_wi_c_grid_desc,
200 make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
201 make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
205
206 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
207 in_n_ho_wo_c_grid_desc,
209 make_merge_transform(make_tuple(N, Ho, Wo))),
212
213 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
214 in_gemmk0_gemmmraw_gemmk1_grid_desc,
216 make_right_pad_transform(GemmMRaw, GemmMPad),
220
221 // B: weight tensor
222 const auto wei_gemmn_gemmk_grid_desc =
224
225 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
226 wei_gemmn_gemmk_grid_desc,
231
232 // C: output tensor
233 const auto out_gemmmraw_gemmn_grid_desc =
235
236 const auto out_gemmm_gemmn_grid_desc =
237 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
238 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
242
243 // C0: bias tensor: assume a contiguous vector
244 const auto bias_grid_desc_gemmm_gemmn =
246
247 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
248 wei_gemmk0_gemmn_gemmk1_grid_desc,
249 out_gemmm_gemmn_grid_desc,
250 bias_grid_desc_gemmm_gemmn);
251 }
252 else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
253 { // C = odd value
254 const index_t GemmKRaw = Y * X * C;
255 const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
256 const index_t GemmKPad = GemmK - GemmKRaw;
257 const index_t GemmK0 = GemmK / GemmK1Number;
258
259 // A: input tensor
260 const auto in_n_hi_wi_c_grid_desc =
262
263 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
264 in_n_hi_wi_c_grid_desc,
266 make_pad_transform(Hi, InLeftPadH, InRightPadH),
267 make_pad_transform(Wi, InLeftPadW, InRightPadW),
271
272 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
273 in_n_hip_wip_c_grid_desc,
276 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
277 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
281
282 const auto in_gemmkraw_gemmmraw_grid_desc =
283 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
285 make_merge_transform(make_tuple(N, Ho, Wo))),
288
289 const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
290 in_gemmkraw_gemmmraw_grid_desc,
291 make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
292 make_right_pad_transform(GemmMRaw, GemmMPad)),
295
296 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
297 in_gemmk_gemmm_grid_desc,
302
303 // B: weight tensor
304 const auto wei_k_yxc_grid_desc =
306
307 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
308 wei_k_yxc_grid_desc,
310 make_right_pad_transform(GemmKRaw, GemmKPad)),
313
314 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
315 wei_gemmk_gemmn_grid_desc,
320
321 // C: output tensor
322 const auto out_nhowo_k_grid_desc =
324
325 const auto out_gemmmraw_gemmn_grid_desc =
326 transform_tensor_descriptor(out_nhowo_k_grid_desc,
331
332 const auto out_gemmm_gemmn_grid_desc =
333 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
334 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
338
339 // C0: bias tensor: assume a contiguous vector
340 const auto bias_grid_desc_gemmm_gemmn =
342
343 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
344 wei_gemmk0_gemmn_gemmk1_grid_desc,
345 out_gemmm_gemmn_grid_desc,
346 bias_grid_desc_gemmm_gemmn);
347 }
348 else
349 {
350 const index_t GemmK = Y * X * C;
351 assert(GemmK % GemmK1Number == 0);
352
353 const index_t GemmK0 = GemmK / GemmK1Number;
354
355 // A: input tensor
356 const auto in_n_hi_wi_c_grid_desc =
358
359 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
360 in_n_hi_wi_c_grid_desc,
362 make_pad_transform(Hi, InLeftPadH, InRightPadH),
363 make_pad_transform(Wi, InLeftPadW, InRightPadW),
367
368 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
369 in_n_hip_wip_c_grid_desc,
372 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
373 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
377
378 const auto in_gemmk_gemmmraw_grid_desc =
379 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
381 make_merge_transform(make_tuple(N, Ho, Wo))),
384
385 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
386 in_gemmk_gemmmraw_grid_desc,
391
392 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
393 in_gemmk0_gemmmraw_gemmk1_grid_desc,
395 make_right_pad_transform(GemmMRaw, GemmMPad),
399
400 // B: weight tensor
401 const auto wei_k_yxc_grid_desc =
403
404 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
405 wei_k_yxc_grid_desc,
409
410 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
411 wei_gemmk_gemmn_grid_desc,
416
417 // C: output tensor
418 const auto out_nhowo_k_grid_desc =
420
421 const auto out_gemmmraw_gemmn_grid_desc =
422 transform_tensor_descriptor(out_nhowo_k_grid_desc,
427
428 const auto out_gemmm_gemmn_grid_desc =
429 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
430 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
434
435 // C0: bias tensor: assume a contiguous vector
436 const auto bias_grid_desc_gemmm_gemmn =
438
439 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
440 wei_gemmk0_gemmn_gemmk1_grid_desc,
441 out_gemmm_gemmn_grid_desc,
442 bias_grid_desc_gemmm_gemmn);
443 }
444 }
445
447 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
448
453
454 // GridwiseGemm
455 template <index_t NXdlPerWave_>
457 BlockSize,
458 ABDataType, // TODO: distinguish A/B datatype
459 AccDataType,
460 CDataType,
461 OutGlobalMemoryDataOperation,
466 InElementwiseOperation,
467 WeiElementwiseOperation,
468 OutElementwiseOperation,
469 MPerBlock,
470 NPerBlock,
471 K0PerBlock,
472 MPerXDL,
473 NPerXDL,
474 K1,
475 MXdlPerWave,
476 NXdlPerWave_,
477 ABlockTransferThreadClusterLengths_K0_M_K1,
478 Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
479 Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
480 2, // ABlockTransferSrcVectorDim,
481 ABlockTransferSrcScalarPerVector,
482 ABlockTransferDstScalarPerVector_K1,
483 false, // AThreadTransferSrcResetCoordinateAfterRun,
484 ABlockLdsAddExtraM,
485 BBlockTransferThreadClusterLengths_K0_N_K1,
486 Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
487 Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
488 2, // BBlockTransferSrcVectorDim,
489 BBlockTransferSrcScalarPerVector,
490 BBlockTransferDstScalarPerVector_K1,
491 false, // BThreadTransferSrcResetCoordinateAfterRun,
492 BBlockLdsAddExtraN,
493 CShuffleMXdlPerWavePerShuffle,
494 CShuffleNXdlPerWavePerShuffle,
495 CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
496 CBlockTransferScalarPerVector_NWaveNPerXdl>;
499
500 // Argument
501 struct Argument : public BaseArgument
502 {
503 Argument(const InDataType* p_in_grid,
504 const WeiDataType* p_wei_grid,
505 OutDataType* p_out_grid,
506 const OutDataType* p_bias_grid,
507 ck::index_t N,
508 ck::index_t K,
509 ck::index_t C,
510 std::vector<ck::index_t> input_spatial_lengths,
511 std::vector<ck::index_t> filter_spatial_lengths,
512 std::vector<ck::index_t> output_spatial_lengths,
513 std::vector<ck::index_t> conv_filter_strides,
514 std::vector<ck::index_t> conv_filter_dilations,
515 std::vector<ck::index_t> input_left_pads,
516 std::vector<ck::index_t> input_right_pads,
517 ck::index_t M01,
518 ck::index_t N01,
519 InElementwiseOperation in_element_op,
520 WeiElementwiseOperation wei_element_op,
521 OutElementwiseOperation out_element_op)
522 : p_a_grid_{p_in_grid},
523 p_b_grid_{p_wei_grid},
524 p_c_grid_{p_out_grid},
525 p_c0_grid_{p_bias_grid},
531 M01_{M01},
532 N01_{N01},
533 in_element_op_{in_element_op},
534 wei_element_op_{wei_element_op},
535 out_element_op_{out_element_op},
536 Conv_N_{N},
537 Conv_K_{K},
538 Conv_C_{C},
539 input_spatial_lengths_{input_spatial_lengths},
540 filter_spatial_lengths_{filter_spatial_lengths},
541 output_spatial_lengths_{output_spatial_lengths},
542 conv_filter_strides_{conv_filter_strides},
543 conv_filter_dilations_{conv_filter_dilations},
544 input_left_pads_{input_left_pads},
545 input_right_pads_{input_right_pads}
546 {
547 const auto descs =
549 K,
550 C,
551 input_spatial_lengths,
552 filter_spatial_lengths,
553 output_spatial_lengths,
554 conv_filter_strides,
555 conv_filter_dilations,
556 input_left_pads,
557 input_right_pads);
558
559 a_grid_desc_k0_m_k1_ = descs[I0];
560 b_grid_desc_k0_n_k1_ = descs[I1];
561 c_grid_desc_m_n_ = descs[I2];
562 c0_grid_desc_m_n_ = descs[I3];
565 }
566
578 InElementwiseOperation in_element_op_;
579 WeiElementwiseOperation wei_element_op_;
580 OutElementwiseOperation out_element_op_;
581 // for checking IsSupportedArgument()
585 std::vector<index_t> input_spatial_lengths_;
586 std::vector<index_t> filter_spatial_lengths_;
587 std::vector<index_t> output_spatial_lengths_;
588 std::vector<index_t> conv_filter_strides_;
589 std::vector<index_t> conv_filter_dilations_;
590 std::vector<index_t> input_left_pads_;
591 std::vector<index_t> input_right_pads_;
592 };
593
594 // Invoker
595 struct Invoker : public BaseInvoker
596 {
598
599 template <typename GridwiseGemm>
600 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
601 {
602 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
603 {
604 std::cout << DeviceOp{}.GetTypeString() << std::endl;
605 std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
606 << arg.Conv_C_ << ", " << std::endl;
607 std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
608 << arg.filter_spatial_lengths_[1] << ", " << std::endl;
609 std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
610 << arg.input_spatial_lengths_[1] << ", " << std::endl;
611 std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
612 << arg.output_spatial_lengths_[1] << ", " << std::endl;
613 std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
614 << arg.conv_filter_strides_[1] << ", " << std::endl;
615 std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
616 << arg.conv_filter_dilations_[1] << ", " << std::endl;
617 std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
618 << arg.input_left_pads_[1] << ", " << std::endl;
619 std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
620 << arg.input_right_pads_[1] << ", " << std::endl;
621
622 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
623 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
624 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
625
626 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
627 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
628 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
629
630 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
631 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
632
633 std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
634 << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
635 }
636
637 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
641 {
642 throw std::runtime_error(
643 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
644 }
645
646 const index_t grid_size =
647 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
648
649 const auto K =
650 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
651
652 float ave_time = 0;
653
654 auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
655 GridwiseGemm::
656 MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
657 arg.c_grid_desc_m_n_);
658
659 auto c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
660 GridwiseGemm::
661 MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
663 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
664 {
665 const auto kernel = kernel_gemm_xdlops_v3r2<
666 GridwiseGemm,
667 ADataType, // TODO: distinguish A/B datatype
668 CDataType,
672 typename GridwiseGemm::
673 CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
675 typename GridwiseGemm::
676 C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
677 InElementwiseOperation,
678 WeiElementwiseOperation,
679 OutElementwiseOperation,
681 true>;
682
683 ave_time = launch_and_time_kernel(
684 stream_config,
685 kernel,
686 dim3(grid_size),
687 dim3(BlockSize),
688 0,
689 arg.p_a_grid_,
690 arg.p_b_grid_,
691 arg.p_c_grid_,
692 arg.p_c0_grid_,
695 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
696 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
697 arg.in_element_op_,
698 arg.wei_element_op_,
699 arg.out_element_op_,
701 }
702 else
703 {
704 const auto kernel = kernel_gemm_xdlops_v3r2<
705 GridwiseGemm,
706 ADataType, // TODO: distinguish A/B datatype
707 CDataType,
711 typename GridwiseGemm::
712 CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
714 typename GridwiseGemm::
715 C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
716 InElementwiseOperation,
717 WeiElementwiseOperation,
718 OutElementwiseOperation,
720 false>;
721
722 ave_time = launch_and_time_kernel(
723 stream_config,
724 kernel,
725 dim3(grid_size),
726 dim3(BlockSize),
727 0,
728 arg.p_a_grid_,
729 arg.p_b_grid_,
730 arg.p_c_grid_,
731 arg.p_c0_grid_,
734 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
735 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
736 arg.in_element_op_,
737 arg.wei_element_op_,
738 arg.out_element_op_,
740 }
741
742 return ave_time;
743 }
744
746
747 float Run(const BaseArgument* p_arg,
748 const StreamConfig& stream_config = StreamConfig{}) override
749 {
750 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
751 }
752 };
753
/// Compile-time hook for validating the template parameters.
/// Currently a placeholder: every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool every_instantiation_accepted = true;
    return every_instantiation_accepted;
}
759
760 static bool IsSupportedArgument(const Argument& arg)
761 {
763 {
764 return false;
765 }
766 if constexpr(ConvForwardSpecialization ==
768 {
769 // check if it's 1x1, stride=1 conv
770 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
771 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
772 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
773 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
774 {
775 return false;
776 }
777 }
778 else if constexpr(ConvForwardSpecialization ==
780 {
781 // check if it's 1x1 conv
782 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
783 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
784 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
785 {
786 return false;
787 }
788 }
789
790 // vector load A/B matrix from global memory
791 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
792 arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
793 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
794 {
795 return false;
796 }
797
798 // vector store C matrix into global memory
799 if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
800 {
801 return false;
802 }
803
804 // Gridwise GEMM size
805 if(get_warp_size() == 64)
806 {
807 if constexpr(NXdlPerWave64 > 0)
808 {
813 }
814 }
815 else
816 {
817 if constexpr(NXdlPerWave32 > 0)
818 {
823 }
824 }
825 return false;
826 }
827
828 bool IsSupportedArgument(const BaseArgument* p_arg) override
829 {
830 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
831 }
832
833 static auto MakeArgument(const InDataType* p_in_grid,
834 const WeiDataType* p_wei_grid,
835 OutDataType* p_out_grid,
836 const OutDataType* p_bias_grid,
837 ck::index_t N,
838 ck::index_t K,
839 ck::index_t C,
840 std::vector<ck::index_t> input_spatial_lengths,
841 std::vector<ck::index_t> filter_spatial_lengths,
842 std::vector<ck::index_t> output_spatial_lengths,
843 std::vector<ck::index_t> conv_filter_strides,
844 std::vector<ck::index_t> conv_filter_dilations,
845 std::vector<ck::index_t> input_left_pads,
846 std::vector<ck::index_t> input_right_pads,
847 InElementwiseOperation in_element_op,
848 WeiElementwiseOperation wei_element_op,
849 OutElementwiseOperation out_element_op)
850 {
851 return Argument{p_in_grid,
852 p_wei_grid,
853 p_out_grid,
854 p_bias_grid,
855 N,
856 K,
857 C,
858 input_spatial_lengths,
859 filter_spatial_lengths,
860 output_spatial_lengths,
861 conv_filter_strides,
862 conv_filter_dilations,
863 input_left_pads,
864 input_right_pads,
865 1,
866 1,
867 in_element_op,
868 wei_element_op,
869 out_element_op};
870 }
871
872 static auto MakeInvoker() { return Invoker{}; }
873
874 std::unique_ptr<BaseArgument>
875 MakeArgumentPointer(const void* p_in_grid,
876 const void* p_wei_grid,
877 void* p_out_grid,
878 const void* p_bias_grid,
879 ck::index_t N,
880 ck::index_t K,
881 ck::index_t C,
882 std::vector<ck::index_t> input_spatial_lengths,
883 std::vector<ck::index_t> filter_spatial_lengths,
884 std::vector<ck::index_t> output_spatial_lengths,
885 std::vector<ck::index_t> conv_filter_strides,
886 std::vector<ck::index_t> conv_filter_dilations,
887 std::vector<ck::index_t> input_left_pads,
888 std::vector<ck::index_t> input_right_pads,
889 InElementwiseOperation in_element_op,
890 WeiElementwiseOperation wei_element_op,
891 OutElementwiseOperation out_element_op) override
892 {
893 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
894 static_cast<const WeiDataType*>(p_wei_grid),
895 static_cast<OutDataType*>(p_out_grid),
896 static_cast<const OutDataType*>(p_bias_grid),
897 N,
898 K,
899 C,
900 input_spatial_lengths,
901 filter_spatial_lengths,
902 output_spatial_lengths,
903 conv_filter_strides,
904 conv_filter_dilations,
905 input_left_pads,
906 input_right_pads,
907 1,
908 1,
909 in_element_op,
910 wei_element_op,
911 out_element_op);
912 }
913
914 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
915 {
916 return std::make_unique<Invoker>(Invoker{});
917 }
918
919 std::string GetTypeString() const override
920 {
921 auto str = std::stringstream();
922
923 // clang-format off
924 str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
925 << "<"
926 << BlockSize << ", "
927 << MPerBlock << ", "
928 << NPerBlock << ", "
929 << K0PerBlock << ", "
930 << K1 << ", "
931 << MPerXDL << ", "
932 << NPerXDL << ", "
933 << MXdlPerWave << ", "
934 << NXdlPerWave << ", "
935 << ABlockTransferSrcScalarPerVector << ", "
936 << ABlockTransferDstScalarPerVector_K1 << ", "
937 << BBlockTransferSrcScalarPerVector << ", "
938 << BBlockTransferDstScalarPerVector_K1 << ", "
939 << CShuffleMXdlPerWavePerShuffle << ", "
940 << CShuffleNXdlPerWavePerShuffle << ", "
941 << CBlockTransferScalarPerVector_NWaveNPerXdl
942 << ">";
943 // clang-format on
944
945 return str.str();
946 }
947};
948} // namespace device
949} // namespace tensor_operation
950} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ OddC
Definition convolution_forward_specialization.hpp:19
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdlops_v3r2(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdlops_v3r2.hpp:36
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v3r2.hpp:133
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
DeviceOp::Argument Argument
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:597
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:747
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:600
index_t Conv_C_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:584
const BDataType * p_b_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:568
index_t M01_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:576
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:571
index_t N01_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:577
std::vector< index_t > input_right_pads_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:591
std::vector< index_t > conv_filter_dilations_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:589
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:575
C0GridDesc_M_N c0_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:574
const CDataType * p_c0_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:570
std::vector< index_t > filter_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:586
WeiElementwiseOperation wei_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:579
OutElementwiseOperation out_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:580
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:572
std::vector< index_t > input_left_pads_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:590
Argument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, const OutDataType *p_bias_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, ck::index_t M01, ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:503
std::vector< index_t > output_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:587
index_t Conv_N_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:582
InElementwiseOperation in_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:578
std::vector< index_t > conv_filter_strides_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:588
index_t Conv_K_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:583
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:573
CDataType * p_c_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:569
std::vector< index_t > input_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:585
const ADataType * p_a_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:567
std::string GetTypeString() const override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:919
static auto MakeArgument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, const OutDataType *p_bias_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:833
static constexpr auto K1Number
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:91
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< BlockSize, ABDataType, AccDataType, CDataType, OutGlobalMemoryDataOperation, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, C0GridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl > GridwiseGemmBase
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:456
static constexpr auto I2
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:88
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:95
WeiDataType BDataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:77
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:498
static constexpr index_t NDimSpatial
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:84
static constexpr auto I0
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:86
static auto MakeInvoker()
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:872
static constexpr auto GemmK1Number
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:92
static constexpr auto I1
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:87
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:69
static constexpr auto NXdlPerWave32
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:74
InDataType ADataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:76
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, const void *p_wei_grid, void *p_out_grid, const void *p_bias_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:875
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:828
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:497
remove_cvref_t< decltype(ABCGridDescs{}[I3])> C0GridDesc_M_N
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:452
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:449
static constexpr auto I3
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:89
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:754
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:451
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})) ABCGridDescs
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:446
OutDataType CDataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:78
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:73
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:760
InDataType ABDataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:81
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:450
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp:914
Definition device_conv_fwd_bias_activation.hpp:20
#define CK_ENV(name)
Definition utility/env.hpp:129