grouped_convolution_forward_kernel.hpp Source File

grouped_convolution_forward_kernel.hpp Source File

Composable Kernel: grouped_convolution_forward_kernel.hpp Source File
grouped_convolution_forward_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
18
19namespace ck_tile {
20
22template <typename GroupedConvTraitsType_, typename CDElementwise_>
24{
25
27 TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
28 GroupedConvTraitsType_::ConvSpecialization,
29 GroupedConvTraitsType_::VectorSizeA,
30 GroupedConvTraitsType_::VectorSizeB,
31 GroupedConvTraitsType_::VectorSizeC,
32 GroupedConvTraitsType_::NumGroupsToMerge,
33 true>; // Split N enabled
34 using CDElementwise = CDElementwise_;
35 static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
36
// Constructor overload for the one-spatial-dimension case. It is SFINAE-enabled
// only when the traits' input / weight / output layouts are NWGC / GKXC / NWGK,
// and copies the host-side problem description in `args` into the kernel-argument
// fields of this struct.
// NOTE(review): this listing is an extract; the declaration line naming the
// constructor and its `args` parameter (listing line 45) and listing lines 62,
// 82-88, 90, 92, 94, 109 and 119 are absent. Consult the original header for them.
37 template <
38 typename InLay = typename GroupedConvTraitsType_::InLayout,
39 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
40 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
41 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
42 std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
43 std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
44 bool>::type = false>
46 : elfunc(args.elfunc)
47 {
// Lengths are stored as (G, N, C, Wi), (G, K, C, X) and (G, N, K, Wo), narrowed
// to index_t from whatever integer type `args` uses.
48 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
49 static_cast<index_t>(args.N_),
50 static_cast<index_t>(args.C_),
51 static_cast<index_t>(args.input_spatial_lengths_[0])};
52 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
53 static_cast<index_t>(args.K_),
54 static_cast<index_t>(args.C_),
55 static_cast<index_t>(args.filter_spatial_lengths_[0])};
56 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
57 static_cast<index_t>(args.N_),
58 static_cast<index_t>(args.K_),
59 static_cast<index_t>(args.output_spatial_lengths_[0])};
60
61 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
// NOTE(review): listing line 62 is missing here; the 2D/3D overloads set
// conv_filter_dilations at this position, so presumably this one does too -- confirm.
63 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
64 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
65
66 k_batch = args.k_batch;
67
68 // GemmM will be set after Split-N calculation
// Implicit-GEMM sizes: N = K output channels, K = C * filter width, one GEMM per group.
69 GemmN = args.K_;
70 GemmK = args.C_ * args.filter_spatial_lengths_[0];
71 GemmBatch = args.G_;
72
73 in_ptr = args.in_ptr;
74 wei_ptr = args.wei_ptr;
// Copy the NumDTensor auxiliary "D" tensor pointers one by one.
75 for(index_t d = 0; d < NumDTensor; d++)
76 {
77 ds_ptr[d] = args.ds_ptr[d];
78 }
79 out_ptr = args.out_ptr;
80
81 // Create and STORE transformer (for split-image support)
// NOTE(review): the statement that constructs `transformer_` (listing lines 82-88)
// is missing from this extract.
89
// NOTE(review): the assignment targets of the three descriptor expressions below
// (listing lines 90/92/94) are missing; the kernel later reads a_grid_desc_m_k,
// b_grid_desc_n_k and c_grid_desc_m_n, which are presumably what these initialize.
91 transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
93 transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
95 transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
96
// Per-group element offsets (the kernel multiplies them by blockIdx.y):
// C for the input, K * C * prod(filter spatial lengths) for the weights, K for the output.
97 group_stride_a = args.C_;
98 group_stride_b = args.K_ * args.C_ *
99 std::accumulate(args.filter_spatial_lengths_.begin(),
100 args.filter_spatial_lengths_.end(),
101 1,
102 std::multiplies<index_t>());
103 group_stride_c = args.K_;
104
105 // Initialize Split-N support fields for 1D convolution (NWGC layout)
106 // Get the actual split N from transformer
107 n_per_split = transformer_.GetN();
108 original_n = transformer_.GetOriginalN();
// NOTE(review): listing line 109 is missing; `n_splits` is not assigned in the
// visible lines, so it is presumably computed there -- confirm.
110
111 // Calculate batch strides using the original argument dimensions.
112 // These are the original dimensions passed to the constructor, not modified by the invoker
113 // yet. (The invoker modifies args after calling MakeKernelArgs.) VERIFIED: G_ MUST be
114 // included - NWGC layout has all groups within each batch
115 input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
116 output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
117
118 // Update GemmM to use split N (not original N)
// NOTE(review): the GemmM assignment itself (listing line 119) is missing from this extract.
120 }
121
// Constructor overload for the two-spatial-dimension case. It is SFINAE-enabled
// only when the traits' input / weight / output layouts are NHWGC / GKYXC / NHWGK,
// and copies the host-side problem description in `args` into the kernel-argument
// fields of this struct.
// NOTE(review): this listing is an extract; the declaration line naming the
// constructor and its `args` parameter (listing line 130) and listing lines 151,
// 174-180, 182, 184, 186, 201, 205, 207 and 211 are absent.
122 template <
123 typename InLay = typename GroupedConvTraitsType_::InLayout,
124 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
125 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
126 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
127 std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
128 std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
129 bool>::type = false>
131 : elfunc(args.elfunc)
132 {
// Lengths are stored as (G, N, C, spatial[0], spatial[1]) for the input,
// (G, K, C, ...) for the weights and (G, N, K, ...) for the output.
133 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
134 static_cast<index_t>(args.N_),
135 static_cast<index_t>(args.C_),
136 static_cast<index_t>(args.input_spatial_lengths_[0]),
137 static_cast<index_t>(args.input_spatial_lengths_[1])};
138 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
139 static_cast<index_t>(args.K_),
140 static_cast<index_t>(args.C_),
141 static_cast<index_t>(args.filter_spatial_lengths_[0]),
142 static_cast<index_t>(args.filter_spatial_lengths_[1])};
143 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
144 static_cast<index_t>(args.N_),
145 static_cast<index_t>(args.K_),
146 static_cast<index_t>(args.output_spatial_lengths_[0]),
147 static_cast<index_t>(args.output_spatial_lengths_[1])};
148
149 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
150 static_cast<index_t>(args.conv_filter_strides_[1])};
// NOTE(review): the first line of the dilation initializer (listing line 151) is
// missing; the line below is its continuation.
152 static_cast<index_t>(args.conv_filter_dilations_[1])};
153 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
154 static_cast<index_t>(args.input_left_pads_[1])};
155 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
156 static_cast<index_t>(args.input_right_pads_[1])};
157
158 k_batch = args.k_batch;
159
160 // Note: GemmM will be set after Split-N calculation
// Implicit-GEMM sizes: N = K output channels, K = C * both filter extents, one GEMM per group.
161 GemmN = args.K_;
162 GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
163 GemmBatch = args.G_;
164
165 in_ptr = args.in_ptr;
166 wei_ptr = args.wei_ptr;
// Copy the NumDTensor auxiliary "D" tensor pointers one by one.
167 for(index_t d = 0; d < NumDTensor; d++)
168 {
169 ds_ptr[d] = args.ds_ptr[d];
170 }
171 out_ptr = args.out_ptr;
172
173 // Create and STORE transformer (for split-image support)
// NOTE(review): the statement that constructs `transformer_` (listing lines 174-180)
// is missing from this extract.
181
// NOTE(review): the assignment targets of the three descriptor expressions below
// (listing lines 182/184/186) are missing; the kernel later reads a_grid_desc_m_k,
// b_grid_desc_n_k and c_grid_desc_m_n, which are presumably what these initialize.
183 transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
185 transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
187 transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
188
// Per-group element offsets (the kernel multiplies them by blockIdx.y):
// C for the input, K * C * prod(filter spatial lengths) for the weights, K for the output.
189 group_stride_a = args.C_;
190 group_stride_b = args.K_ * args.C_ *
191 std::accumulate(args.filter_spatial_lengths_.begin(),
192 args.filter_spatial_lengths_.end(),
193 1,
194 std::multiplies<index_t>());
195 group_stride_c = args.K_;
196
197 // Initialize Split-N support fields for 2D convolution (NHWGC layout)
198 // Get the actual split N from transformer
199 n_per_split = transformer_.GetN();
200 original_n = transformer_.GetOriginalN();
// NOTE(review): listing line 201 is missing; `n_splits` is not assigned in the
// visible lines, so it is presumably computed there -- confirm.
202
203 // Calculate batch strides for NHWGC layout
204 // VERIFIED: G_ MUST be included - NHWGC layout has all groups within each batch
// NOTE(review): the left-hand sides of the two products below (listing lines 205
// and 207) are missing; by analogy with the 1D overload they are presumably
// input_batch_stride and output_batch_stride.
206 args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
208 args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
209
210 // Update GemmM to use split N (not original N)
// NOTE(review): the GemmM assignment itself (listing line 211) is missing from this extract.
212 }
213
// Constructor overload for the three-spatial-dimension case. It is SFINAE-enabled
// only when the traits' input / weight / output layouts are NDHWGC / GKZYXC / NDHWGK,
// and copies the host-side problem description in `args` into the kernel-argument
// fields of this struct.
// NOTE(review): this listing is an extract; the declaration line naming the
// constructor and its `args` parameter (listing line 222) and listing lines 247,
// 262, 274-280, 282, 284, 286, 301, 306, 308 and 311-312 are absent.
214 template <
215 typename InLay = typename GroupedConvTraitsType_::InLayout,
216 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
217 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
218 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
219 std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
220 std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
221 bool>::type = false>
223 : elfunc(args.elfunc)
224 {
// Lengths are stored as (G, N, C, spatial[0..2]) for the input,
// (G, K, C, ...) for the weights and (G, N, K, ...) for the output.
225 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
226 static_cast<index_t>(args.N_),
227 static_cast<index_t>(args.C_),
228 static_cast<index_t>(args.input_spatial_lengths_[0]),
229 static_cast<index_t>(args.input_spatial_lengths_[1]),
230 static_cast<index_t>(args.input_spatial_lengths_[2])};
231 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
232 static_cast<index_t>(args.K_),
233 static_cast<index_t>(args.C_),
234 static_cast<index_t>(args.filter_spatial_lengths_[0]),
235 static_cast<index_t>(args.filter_spatial_lengths_[1]),
236 static_cast<index_t>(args.filter_spatial_lengths_[2])};
237 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
238 static_cast<index_t>(args.N_),
239 static_cast<index_t>(args.K_),
240 static_cast<index_t>(args.output_spatial_lengths_[0]),
241 static_cast<index_t>(args.output_spatial_lengths_[1]),
242 static_cast<index_t>(args.output_spatial_lengths_[2])};
243
244 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
245 static_cast<index_t>(args.conv_filter_strides_[1]),
246 static_cast<index_t>(args.conv_filter_strides_[2])};
// NOTE(review): the first line of the dilation initializer (listing line 247) is
// missing; the two lines below are its continuation.
248 static_cast<index_t>(args.conv_filter_dilations_[1]),
249 static_cast<index_t>(args.conv_filter_dilations_[2])};
250 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
251 static_cast<index_t>(args.input_left_pads_[1]),
252 static_cast<index_t>(args.input_left_pads_[2])};
253 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
254 static_cast<index_t>(args.input_right_pads_[1]),
255 static_cast<index_t>(args.input_right_pads_[2])};
256
257 k_batch = args.k_batch;
258
259 // Note: GemmM will be set after Split-N calculation
260 GemmN = args.K_;
// NOTE(review): the final factor of the GemmK product (listing line 262) is missing;
// presumably the third filter extent -- confirm.
261 GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
263 GemmBatch = args.G_;
264
265 in_ptr = args.in_ptr;
266 wei_ptr = args.wei_ptr;
// Copy the NumDTensor auxiliary "D" tensor pointers one by one.
267 for(index_t d = 0; d < NumDTensor; d++)
268 {
269 ds_ptr[d] = args.ds_ptr[d];
270 }
271 out_ptr = args.out_ptr;
272
273 // Create and STORE transformer (for split-image support)
// NOTE(review): the statement that constructs `transformer_` (listing lines 274-280)
// is missing from this extract.
281
// NOTE(review): the assignment targets of the three descriptor expressions below
// (listing lines 282/284/286) are missing; the kernel later reads a_grid_desc_m_k,
// b_grid_desc_n_k and c_grid_desc_m_n, which are presumably what these initialize.
283 transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
285 transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
287 transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
288
// Per-group element offsets (the kernel multiplies them by blockIdx.y):
// C for the input, K * C * prod(filter spatial lengths) for the weights, K for the output.
289 group_stride_a = args.C_;
290 group_stride_b = args.K_ * args.C_ *
291 std::accumulate(args.filter_spatial_lengths_.begin(),
292 args.filter_spatial_lengths_.end(),
293 1,
294 std::multiplies<index_t>());
295 group_stride_c = args.K_;
296
297 // Initialize Split-N support fields for 3D convolution (NDHWGC layout)
298 // Get the actual split N from transformer
299 n_per_split = transformer_.GetN();
300 original_n = transformer_.GetOriginalN();
// NOTE(review): listing line 301 is missing; `n_splits` is not assigned in the
// visible lines, so it is presumably computed there -- confirm.
302
303 // Calculate batch strides for NDHWGC layout
304 // VERIFIED: G_ MUST be included - NDHWGC layout has all groups within each batch
// NOTE(review): the continuation lines of both stride products (listing lines 306
// and 308) are missing; presumably the remaining two spatial extents -- confirm.
305 input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
307 output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
309
310 // Update GemmM to use split N (not original N)
// NOTE(review): the GemmM assignment itself (listing lines 311-312) is missing from this extract.
313 }
314
316 decltype(ConvToGemmFwdTransformer{}
317 .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
319 decltype(ConvToGemmFwdTransformer{}
320 .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
322 decltype(ConvToGemmFwdTransformer{}
323 .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
324
325 static constexpr index_t NonSpatialDims = 3;
329
334
340
341 const void* in_ptr;
342 const void* wei_ptr;
343 std::array<const void*, NumDTensor> ds_ptr;
345 void* out_ptr;
346
350
354
355 // Split-N support fields - initialize to safe defaults
356 index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
357 index_t n_per_split = 1; // Batches per split (N_ from transformer)
358 index_t original_n = 1; // Original batch size before splitting
359 index_t input_batch_stride = 0; // Stride to next batch in input tensor
360 index_t output_batch_stride = 0; // Stride to next batch in output tensor
361
362 // Split-image support - spatial offsets (applied per-batch in operator())
363 long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split)
364 long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split)
365
366 // Split-image support - transformer instance
368
369 // Forward declare descriptor types (will be defined after using declarations)
373
374 // Split-image support: Common data for all pieces
376 {
377 // Common dimensions (same for all pieces)
378 index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
379 index_t total_spatial = 1; // Pre-calculated: total_d * total_h * total_w
380 index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors
381
382 // Minimal per-piece data (only unique values)
384 {
385 index_t block_start; // Starting block index for this piece
386 index_t block_end; // Ending block index (exclusive)
387 index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
388 index_t d_size, h_size, w_size; // Piece size in OUTPUT space
389 };
390
391 static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
392 std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
393 };
394
395 index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
396 SplitImageInfo split_image; // Nested structure with common + per-piece data
397};
398
437template <typename GroupedConvTraitsType_,
438 typename TilePartitioner_,
439 typename GemmPipeline_,
440 typename EpiloguePipeline_>
442{
443 static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
444 static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
446 GroupedConvTraitsType_::ConvSpecialization;
453
458
460 static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
461
462 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
463
467 // Below type is actually accumulation data type - the output of block GEMM.
469
470 using CDElementwise = typename EpiloguePipeline::CDElementwise;
471
474
475 static constexpr bool IsSplitKSupported = false;
476
477 static constexpr auto I0 = number<0>();
478 static constexpr auto I1 = number<1>();
479 static constexpr auto I2 = number<2>();
480 static constexpr auto I3 = number<3>();
481
482 static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
483 "Not supported!");
484 static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
485 static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
486 static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
487
488 // Helper struct for spatial coordinates
490 {
492 };
493
494 // Helper: Convert flat spatial index to (d,h,w) coordinates
// NOTE(review): the signature (listing lines 495-496) is missing from this extract.
// The body reads `flat`, `h_size` and `w_size` and returns a SpatialCoords; the
// call site passes (flat index, piece.h_size, piece.w_size) and reads .d/.h/.w
// from the result, so the aggregate is presumably ordered {d, h, w}.
497 {
498 if constexpr(NDimSpatial == 1)
499 {
// 1D: the flat index is the w coordinate; the other two are zero.
500 return SpatialCoords{0, 0, flat};
501 }
502 else if constexpr(NDimSpatial == 2)
503 {
// 2D: row-major decode with w fastest -- quotient and remainder by w_size.
504 return SpatialCoords{0, flat / w_size, flat % w_size};
505 }
506 else // NDimSpatial == 3
507 {
// 3D: peel off whole h*w planes first, then decode the remainder as in 2D.
508 const index_t hw = h_size * w_size;
509 const index_t d = flat / hw;
510 const index_t remainder = flat % hw;
511 return SpatialCoords{d, remainder / w_size, remainder % w_size};
512 }
513 }
514
515 // Helper: Convert (d,h,w) to flat spatial index
// NOTE(review): the signature (listing lines 516-517) is missing from this extract.
// The body reads `d`, `h`, `w`, `total_h` and `total_w`; the call site passes them
// in exactly that order.
518 {
519 if constexpr(NDimSpatial == 1)
520 {
// 1D: only w contributes.
521 return w;
522 }
523 else if constexpr(NDimSpatial == 2)
524 {
// 2D: row-major with w fastest; d and total_h are unused.
525 return h * total_w + w;
526 }
527 else // NDimSpatial == 3
528 {
// 3D: row-major over (d, h, w) with w fastest.
529 return (d * total_h + h) * total_w + w;
530 }
531 }
532
533 // Helper: Find which piece owns a block using binary search
// Returns the index into split_info.pieces whose half-open range
// [block_start, block_end) contains block_id.
//   block_id   - block index to locate
//   split_info - any type exposing an indexable `pieces` member whose elements
//                have `block_start` and `block_end`
//   num_pieces - number of valid leading entries in `pieces`
// The bisection presumes the pieces are ordered by ascending block range --
// TODO confirm against the code that fills them in.
// If no piece contains block_id the loop stops once left > right and the last
// midpoint is returned; no separate "not found" value is produced.
// NOTE(review): the line carrying the function qualifiers and return type
// (listing line 535) is missing from this extract.
534 template <typename SplitImageInfo>
536 FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
537 {
538 index_t left = 0;
539 index_t right = num_pieces - 1;
540 index_t piece_id = (left + right) / 2;
541
// Note the order of the conjunction: pieces[piece_id] is inspected before the
// `left <= right` test is evaluated.
542 while(!(block_id >= split_info.pieces[piece_id].block_start &&
543 block_id < split_info.pieces[piece_id].block_end) &&
544 left <= right)
545 {
546 if(block_id < split_info.pieces[piece_id].block_start)
547 {
// Target lies in an earlier piece: discard the upper half.
548 right = piece_id - 1;
549 }
550 else
551 {
// Otherwise block_id >= block_end here: discard the lower half.
552 left = piece_id + 1;
553 }
554 piece_id = (left + right) / 2;
555 }
556 return piece_id;
557 }
558
// Host-side helper returning a descriptive name for this kernel instantiation,
// assembled by concat() from fixed tags plus the GEMM pipeline's and epilogue
// pipeline's own names. The leading '_' is presumably the separator placed
// between parts -- confirm against concat().
// NOTE(review): one argument line (listing line 563) is missing from this extract.
559 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
560 {
561 // clang-format off
562 return concat('_', "grouped_convolution_forward",
564 "gemm",
565 GemmPipeline::GetName(),
566 "epilogue",
567 EpiloguePipeline::GetName());
568 // clang-format on
569 }
570
// Launch grid for the kernel arguments `kargs`:
//   x = TilePartitioner::GridSize(GemmM, GemmN), y = GemmBatch, z = n_splits.
// The kernel body uses blockIdx.y to select the group and blockIdx.z to select
// the batch split.
// NOTE(review): the signature (listing line 571) is missing from this extract.
572 {
573 return dim3(
574 TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
575 }
576
// Workgroup size: kBlockSize threads, halved when is_wave32() reports true.
// NOTE(review): the signature (listing line 577) is missing from this extract.
578 {
579 return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
580 }
581
// Builds the kernel-argument object from the host arguments `hostArgs` by invoking
// the GroupedConvFwdKernelArgsSpecialized constructor, and returns it by value.
// NOTE(review): the signature (listing lines 582-583) is missing from this extract.
584 {
585 auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
586 return kargs;
587 }
588
// Shared-memory size needed by one buffer: the larger of what the GEMM pipeline
// and the epilogue pipeline each request. The kernel body uses it as a constant
// array bound.
// NOTE(review): the signature (listing line 589) is missing from this extract.
590 {
591 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
592 }
593
// Validates the kernel arguments `kargs` against this instantiation's compile-time
// configuration. Returns true only if every check below passes.
// NOTE(review): the signature (listing line 594) and several `if constexpr`
// condition lines (listing lines 597-598, 614, 630, 645) are missing from this
// extract; the affected spots are flagged inline.
595 {
// Split-K guard: under a compile-time condition involving an odd VectorSizeC
// (the rest of the condition is among the missing lines), any k_batch other
// than 1 is rejected.
596 if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
599 {
600 if(kargs.k_batch != 1)
601 {
// This message is emitted only when the CK_TILE_LOGGING environment switch is
// enabled; the later CK_TILE_ERROR calls in this function are not guarded that way.
602 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
603 {
604 CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
605 }
606 return false;
607 }
608 }
609
// wei_g_k_c_xs_lengths is (G, K, C, filter spatial...): index 1 is K, index 2 is C.
610 const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
611 const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
612
613 // check ConvolutionSpecialization
// NOTE(review): the three guarding conditions of the branches below are missing;
// judging from the in-body comments they presumably select the 1x1/stride-1/pad-0,
// 1x1/pad-0 and 3x3 specializations respectively -- confirm.
615 {
616 // check if it's 1x1, stride=1 conv
617 for(index_t i = 0; i < NDimSpatial; ++i)
618 {
// Filter spatial extents start at index 3 of wei_g_k_c_xs_lengths.
619 const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
620 const index_t ConvStride = kargs.conv_filter_strides[i];
621 const index_t LeftPad = kargs.input_left_pads[i];
622 const index_t RightPad = kargs.input_right_pads[i];
623
624 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
625 {
626 return false;
627 }
628 }
629 }
631 {
632 // check if it's 1x1 conv
633 for(index_t i = 0; i < NDimSpatial; ++i)
634 {
635 const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
636 const index_t LeftPad = kargs.input_left_pads[i];
637 const index_t RightPad = kargs.input_right_pads[i];
638
639 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
640 {
641 return false;
642 }
643 }
644 }
646 {
// This branch additionally requires exactly one input channel per group.
647 if(ConvC != 1)
648 {
649 return false;
650 }
651 for(index_t i = 0; i < NDimSpatial; ++i)
652 {
653 const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
654
// I3 is number<3>; here it serves as the required filter extent (3) in every
// spatial dimension.
655 if(filter_spatial_dim != I3)
656 {
657 return false;
658 }
659 }
660 }
661
662 namespace ctc = tensor_layout::convolution;
663
// Vector-access checks: the channel count must be a multiple of the vector width
// configured for each tensor; layouts outside the listed ones are rejected.
664 if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
665 std::is_same_v<InLayout, ctc::NDHWGC>)
666 {
667 // Check access per C
668 if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
669 {
670 CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
671 return false;
672 }
673 }
674 else
675 {
676 CK_TILE_ERROR("Not supported input layout!");
677 return false;
678 }
679
680 // check vector access of B
681 // FIXME: layout
682 if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
683 std::is_same_v<WeiLayout, ctc::GKYXC> ||
684 std::is_same_v<WeiLayout, ctc::GKZYXC>)
685 {
686 if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
687 {
688 CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
689 return false;
690 }
691 }
692 else
693 {
694 CK_TILE_ERROR("Not supported weight layout!");
695 return false;
696 }
697
698 // check vector access of E
699 if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
700 std::is_same_v<OutLayout, ctc::NHWGK> ||
701 std::is_same_v<OutLayout, ctc::NDHWGK>)
702 {
// For the output the relevant channel count is K and the width is VectorSizeC.
703 if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
704 {
705 CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
706 return false;
707 }
708 }
709 else
710 {
711 CK_TILE_ERROR("Not supported output layout!");
712 return false;
713 }
714
715 return true;
716 }
717
// Device-side factory that wraps the raw A / B / Ds / C pointers together with
// their descriptors and returns the views as the tuple (A, B, Ds, C).
// NOTE(review): this definition is heavily truncated in this extract. The opening
// `template <` line (718), the line with the function name and first parameter
// (723), the bodies of the three view lambdas (734, 738, 743), the start of the
// return expression inside the Ds lambda (755) and the closing of generate_tuple
// (758) are all missing. Consult the original header before relying on it.
719 typename ADescType,
720 typename BDescType,
721 typename CDescType>
722 CK_TILE_DEVICE static auto
724 const WeiDataType* b_ptr,
725 const std::array<const void*, NumDTensor>& ds_ptr,
726 OutDataType* c_ptr,
727 const ADescType& a_desc,
728 const BDescType& b_desc,
729 const CDescType& c_desc)
730 {
// Permuted A/B block shapes are rejected at compile time.
731 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
732 static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
733 const auto& a_tensor_view = [&]() {
735 }();
736
737 const auto& b_tensor_view = [&]() {
739 }();
740
741 // TODO: enable vector write for C in ColMajor
742 const auto& c_tensor_view = [&]() {
744 }();
745
// Each D tensor must share the output's layout and data type (enforced by the
// static_asserts); its pointer is cast to OutDataType and paired with c_desc.
746 const auto& ds_tensor_view = generate_tuple(
747 [&](auto i) {
748 static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
749 "Not supported!");
750 static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
751 "Not supported!");
752 static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
753 "Not supported!");
754
756 static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
757 },
759
760 return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
761 }
762
// Takes the (A, B, Ds, C) tuple of tensor views and returns a tuple in the same
// order in which every view has been passed through pad_tensor_view.
// NOTE(review): the tile-length and padding arguments of each pad_tensor_view call
// (listing lines 769-771, 777-779, 786-788, 795-797) and the closing of
// generate_tuple (790) are missing from this extract.
763 template <typename TensorView>
764 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
765 {
766 const auto& a_pad_view = [&]() {
767 const auto& a_tensor_view = views.at(I0);
768 return pad_tensor_view(a_tensor_view,
772 }();
773
774 const auto& b_pad_view = [&]() {
775 const auto& b_tensor_view = views.at(I1);
776 return pad_tensor_view(b_tensor_view,
780 }();
781
// The Ds entry is itself a tuple; each element is padded individually.
782 const auto& ds_tensor_view = views.at(I2);
783 const auto& ds_pad_view = generate_tuple(
784 [&](auto i) {
785 return pad_tensor_view(ds_tensor_view[i],
789 },
791
792 const auto& c_pad_view = [&]() {
793 const auto& c_tensor_view = views.at(I3);
794 return pad_tensor_view(c_tensor_view,
798 }();
799
800 return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
801 }
802
// Creates the per-workgroup tile windows over the padded (A, B, Ds, C) views.
//   views - tuple of padded views in (A, B, Ds, C) order
//   i_m   - starting M coordinate of this workgroup's tile
//   i_n   - starting N coordinate of this workgroup's tile
// Window origins: A at {i_m, 0}, B at {i_n, 0}, each D and C at {i_m, i_n}.
// Returns the tuple (A window, B window, Ds windows, C window).
// NOTE(review): the window-length arguments of every make_tile_window call
// (listing lines 814-815, 821-822, 829-830, 837) and the closing of
// generate_tuple (833) are missing from this extract.
803 template <typename PadView>
804 CK_TILE_DEVICE static auto
805 MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
806 {
807 const auto& a_pad_view = views.at(I0);
808 const auto& b_pad_view = views.at(I1);
809 const auto& ds_pad_view = views.at(I2);
810 const auto& c_pad_view = views.at(I3);
811
812 const auto& a_block_window = [&]() {
813 return make_tile_window(a_pad_view,
816 {i_m, 0});
817 }();
818
819 const auto& b_block_window = [&]() {
820 return make_tile_window(b_pad_view,
823 {i_n, 0});
824 }();
825
826 const auto ds_block_window = generate_tuple(
827 [&](auto i) {
828 return make_tile_window(ds_pad_view[i],
831 {i_m, i_n});
832 },
834
// The C window is held by non-const value; RunGemm/RunGemm2LDS later fetch it
// from the returned tuple through a non-const reference.
835 auto c_block_window = make_tile_window(
836 c_pad_view,
838 {i_m, i_n});
839
840 return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
841 }
842
// Runs one output tile of the implicit GEMM with a single shared-memory buffer:
// builds views -> padded views -> tile windows, runs GemmPipeline over the K loop,
// then hands the accumulator tile to EpiloguePipeline.
//   a_ptr, b_ptr, ds_ptr, c_ptr - base pointers of the A, B, D and C tensors
//   smem_ptr_0                  - scratch buffer passed to both the GEMM and the epilogue
//   a_desc, b_desc, c_desc      - descriptors for A, B and C (also used for the Ds)
//   gemm_k                      - GEMM K extent, converted to a loop count by TilePartitioner
//   block_idx_m, block_idx_n    - starting M / N coordinates of the tile
// NOTE(review): the callee line of the tensor-view construction (listing line 874)
// and the original Doxygen comment (lines 843-858) are missing from this extract.
859 template <typename ADescType, typename BDescType, typename CDescType>
860 CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
861 const WeiDataType* b_ptr,
862 const std::array<const void*, NumDTensor>& ds_ptr,
863 OutDataType* c_ptr,
864 void* smem_ptr_0,
865 const ADescType& a_desc,
866 const BDescType& b_desc,
867 const CDescType& c_desc,
868 const index_t gemm_k,
869 const index_t block_idx_m,
870 const index_t block_idx_n)
871 {
872 // Create Gemm tensor views, pad views and tile windows
873 const auto& gemm_tensor_views_tuple =
875 a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
876
877 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
878 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
879
// Number of K-loop iterations; amd_wave_read_first_lane presumably makes the
// value wave-uniform -- confirm against its definition.
880 const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
881
882 // Run GEMM cooperatively by whole workgroup.
883 const auto& a_block_window = gemm_tile_windows.at(I0);
884 const auto& b_block_window = gemm_tile_windows.at(I1);
885 const auto& d_block_window = gemm_tile_windows.at(I2);
886
887 const auto& c_block_tile = GemmPipeline{}.template operator()(
888 a_block_window, b_block_window, num_loop, smem_ptr_0);
889
890 // Run Epilogue Pipeline
891 auto& c_block_window = gemm_tile_windows.at(I3);
892
// The epilogue receives the same smem_ptr_0 buffer the GEMM pipeline used.
893 EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
894 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
895 }
896
// Same flow as RunGemm, but the GEMM pipeline is given two shared-memory buffers
// (smem_ptr_0 and smem_ptr_1). The epilogue still receives only smem_ptr_0.
//   a_ptr, b_ptr, ds_ptr, c_ptr - base pointers of the A, B, D and C tensors
//   smem_ptr_0, smem_ptr_1      - the two scratch buffers, declared __restrict__
//   a_desc, b_desc, c_desc      - descriptors for A, B and C (also used for the Ds)
//   gemm_k                      - GEMM K extent, converted to a loop count by TilePartitioner
//   block_idx_m, block_idx_n    - starting M / N coordinates of the tile
// NOTE(review): the callee line of the tensor-view construction (listing line 932)
// and the original Doxygen comment (lines 897-915) are missing from this extract.
916 template <typename ADescType, typename BDescType, typename CDescType>
917 CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
918 const WeiDataType* b_ptr,
919 const std::array<const void*, NumDTensor>& ds_ptr,
920 OutDataType* c_ptr,
921 void* __restrict__ smem_ptr_0,
922 void* __restrict__ smem_ptr_1,
923 const ADescType& a_desc,
924 const BDescType& b_desc,
925 const CDescType& c_desc,
926 const index_t gemm_k,
927 const index_t block_idx_m,
928 const index_t block_idx_n)
929 {
930 // Create Gemm tensor views, pad views and tile windows
931 const auto& gemm_tensor_views_tuple =
933 a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
934 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
935 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
936
// Number of K-loop iterations; amd_wave_read_first_lane presumably makes the
// value wave-uniform -- confirm against its definition.
937 const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
938
939 // Run GEMM cooperatively by whole workgroup.
940 const auto& a_block_window = gemm_tile_windows.at(I0);
941 const auto& b_block_window = gemm_tile_windows.at(I1);
942 const auto& d_block_window = gemm_tile_windows.at(I2);
943
944 const auto& c_block_tile = GemmPipeline{}.template operator()(
945 a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
946
947 // Run Epilogue Pipeline
948 auto& c_block_window = gemm_tile_windows.at(I3);
949
950 EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
951 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
952 }
953
// Kernel body. Each workgroup:
//   1. derives its group (blockIdx.y) and batch split (blockIdx.z) and offsets the
//      input / weight / output base pointers accordingly;
//   2. maps blockIdx.x to the (i_m, i_n) origin of its output tile, either directly
//      or, when EnableSplitImage is set, through the owning spatial piece;
//   3. declares shared memory and dispatches to RunGemm or RunGemm2LDS.
// NOTE(review): the signature (listing line 954) and the last operand of the two
// `if constexpr` conditions guarding the RunGemm calls (lines 1067 and 1087) are
// missing from this extract. The body reads its arguments through `kargs`.
955 {
// amd_wave_read_first_lane presumably yields a wave-uniform value -- confirm
// against its definition.
956 const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
957 const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
958
// Element offsets of this workgroup's group within each tensor.
959 const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
960 const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
961 const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
962
963 // Split-N handling: Get which split this workgroup handles
964 const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
965
966 // Calculate batch offset for this split
967 const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
968
969 // Calculate memory offsets for this split
// Both operands are widened to long_index_t before multiplying, so the product
// is formed in the wider type rather than in index_t.
970 const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
971 static_cast<long_index_t>(kargs.input_batch_stride);
972 const long_index_t output_batch_offset =
973 static_cast<long_index_t>(batch_offset) *
974 static_cast<long_index_t>(kargs.output_batch_stride);
975
976 // Calculate base pointers with group and batch offsets
977 const InDataType* base_a_ptr =
978 static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
979 const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
980 group_offset_b; // No batch offset for weights!
981 OutDataType* base_c_ptr =
982 static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
983
984 // =====================================================================
985 // Split-image: Map local block to global tile index (if enabled)
986 // =====================================================================
// a_ptr / c_ptr are assigned in exactly one of the two branches below.
987 const InDataType* a_ptr;
988 OutDataType* c_ptr;
989 index_t i_m = 0;
990 index_t i_n = 0;
991
992 // Pre-calculate block_id (used in both split-image and non-split paths)
993 const index_t block_id = static_cast<index_t>(blockIdX);
994
995 if constexpr(EnableSplitImage)
996 {
997 // Add spatial offsets for split-image (constexpr optimization)
998 a_ptr = base_a_ptr + kargs.spatial_offset_in;
999 c_ptr = base_c_ptr + kargs.spatial_offset_out;
1000
1001 // Find which piece owns this block using binary search
1002 // Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
1003 const index_t piece_id =
1004 FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
1005 const auto& piece = kargs.split_image.pieces[piece_id];
1006 const auto& split_info = kargs.split_image;
1007
1008 // Calculate local block ID and tile indices
// The piece is treated as its own GEMM with
// M = n_per_split * (piece spatial volume) and N = GemmN.
1009 const index_t local_block_id = block_id - piece.block_start;
1010 const index_t local_gemm_m =
1011 kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
1012 const auto [local_tile_m, local_tile_n] =
1013 TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
1014
1015 // Extract batch and spatial coordinates from local tile
// Only the first row of the tile (local_m_start) is decoded into
// (batch, spatial) coordinates.
1016 const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
1017 const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
1018 const index_t local_n = local_m_start / spatial_per_batch;
1019 const index_t local_spatial_flat = local_m_start % spatial_per_batch;
1020
1021 // Convert to local spatial coordinates
1022 const auto local_coords =
1023 UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
1024
1025 // Convert to global spatial coordinates
1026 const index_t global_n = local_n;
1027 const index_t global_d = piece.d_start + local_coords.d;
1028 const index_t global_h = piece.h_start + local_coords.h;
1029 const index_t global_w = piece.w_start + local_coords.w;
1030
1031 // Convert to global M index
1032 const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
1033 const index_t global_spatial_flat = FlattenSpatial(
1034 global_d, global_h, global_w, split_info.total_h, split_info.total_w);
1035 const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
1036
1037 // Set tile indices for GEMM operation
1038 i_m = amd_wave_read_first_lane(global_m);
1039 i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
1040 }
1041 else
1042 {
1043 // No spatial offsets needed for regular path
1044 a_ptr = base_a_ptr;
1045 c_ptr = base_c_ptr;
1046
1047 // No split-image: use standard tile partitioning
1048 const auto [iM, iN] =
1049 TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
1050 i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1051 i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1052 }
1053
1054 // Use global descriptors for all cases
1055 const auto& a_desc = kargs.a_grid_desc_m_k;
1056 const auto& b_desc = kargs.b_grid_desc_n_k;
1057 const auto& c_desc = kargs.c_grid_desc_m_n;
1058
1059 // allocate LDS
1060 __shared__ char smem_ptr_0[GetSmemSize()];
1061
1062 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1063 {
// A second shared buffer of the same size is declared only for pipelines
// that set DoubleSmemBuffer.
1064 __shared__ char smem_ptr_1[GetSmemSize()];
// The call below is compiled out when the epilogue uses atomic_add with an
// odd VectorSizeC and a third condition holds (that operand is missing here).
1065 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1066 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1068 {
1069 RunGemm2LDS(a_ptr,
1070 b_ptr,
1071 kargs.ds_ptr,
1072 c_ptr,
1073 smem_ptr_0,
1074 smem_ptr_1,
1075 a_desc,
1076 b_desc,
1077 c_desc,
1078 kargs.GemmK,
1079 i_m,
1080 i_n);
1081 }
1082 }
1083 else
1084 {
// Same compile-time exclusion as in the double-buffer branch above.
1085 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1086 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1088 {
1089 RunGemm(a_ptr,
1090 b_ptr,
1091 kargs.ds_ptr,
1092 c_ptr,
1093 smem_ptr_0,
1094 a_desc,
1095 b_desc,
1096 c_desc,
1097 kargs.GemmK,
1098 i_m,
1099 i_n);
1100 }
1101 }
1102 }
1103};
1104
1105} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/ops/common/tensor_layout.hpp:27
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_specialization.hpp:14
@ Filter3x3
Definition convolution_specialization.hpp:15
@ Filter1x1Pad0
Definition convolution_specialization.hpp:13
memory_operation_enum
Definition arch.hpp:56
@ atomic_add
Definition arch.hpp:58
@ set
Definition arch.hpp:57
GroupedConvHostArgs< const void *, const void *, void *, CDElementwise > GroupedConvFwdHostArgs
Definition grouped_convolution_utils.hpp:50
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition grouped_convolution_forward_kernel.hpp:384
index_t w_size
Definition grouped_convolution_forward_kernel.hpp:388
index_t h_start
Definition grouped_convolution_forward_kernel.hpp:387
index_t w_start
Definition grouped_convolution_forward_kernel.hpp:387
index_t d_size
Definition grouped_convolution_forward_kernel.hpp:388
index_t h_size
Definition grouped_convolution_forward_kernel.hpp:388
index_t block_start
Definition grouped_convolution_forward_kernel.hpp:385
index_t block_end
Definition grouped_convolution_forward_kernel.hpp:386
index_t d_start
Definition grouped_convolution_forward_kernel.hpp:387
Definition grouped_convolution_forward_kernel.hpp:376
index_t num_d_pieces
Definition grouped_convolution_forward_kernel.hpp:380
index_t total_w
Definition grouped_convolution_forward_kernel.hpp:378
index_t total_d
Definition grouped_convolution_forward_kernel.hpp:378
std::array< PieceInfo, MaxPieces > pieces
Definition grouped_convolution_forward_kernel.hpp:392
static constexpr index_t MaxPieces
Definition grouped_convolution_forward_kernel.hpp:391
index_t total_spatial
Definition grouped_convolution_forward_kernel.hpp:379
index_t num_w_pieces
Definition grouped_convolution_forward_kernel.hpp:380
index_t total_h
Definition grouped_convolution_forward_kernel.hpp:378
index_t num_h_pieces
Definition grouped_convolution_forward_kernel.hpp:380
The Grouped Convolution kernel device arguments.
Definition grouped_convolution_forward_kernel.hpp:24
long_index_t group_stride_c
Definition grouped_convolution_forward_kernel.hpp:353
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K< typename GroupedConvTraitsType_::InLayout >())> AGridDescMK
Definition grouped_convolution_forward_kernel.hpp:315
index_t input_batch_stride
Definition grouped_convolution_forward_kernel.hpp:359
static constexpr index_t NonSpatialDims
Definition grouped_convolution_forward_kernel.hpp:325
index_t n_per_split
Definition grouped_convolution_forward_kernel.hpp:357
const CDElementwise elfunc
Definition grouped_convolution_forward_kernel.hpp:344
AGridDescMK a_grid_desc_m_k
Definition grouped_convolution_forward_kernel.hpp:347
TransformConvFwdToGemm< GroupedConvTraitsType_::NDimSpatial, GroupedConvTraitsType_::ConvSpecialization, GroupedConvTraitsType_::VectorSizeA, GroupedConvTraitsType_::VectorSizeB, GroupedConvTraitsType_::VectorSizeC, GroupedConvTraitsType_::NumGroupsToMerge, true > ConvToGemmFwdTransformer
Definition grouped_convolution_forward_kernel.hpp:26
CGridDescMN CGridDescMN_t
Definition grouped_convolution_forward_kernel.hpp:372
const void * in_ptr
Definition grouped_convolution_forward_kernel.hpp:341
index_t GemmM
Definition grouped_convolution_forward_kernel.hpp:336
index_t original_n
Definition grouped_convolution_forward_kernel.hpp:358
long_index_t group_stride_b
Definition grouped_convolution_forward_kernel.hpp:352
ConvToGemmFwdTransformer ConvToGemmFwdTransformer_t
Definition grouped_convolution_forward_kernel.hpp:370
CGridDescMN c_grid_desc_m_n
Definition grouped_convolution_forward_kernel.hpp:349
CDElementwise_ CDElementwise
Definition grouped_convolution_forward_kernel.hpp:34
index_t n_splits
Definition grouped_convolution_forward_kernel.hpp:356
std::array< const void *, NumDTensor > ds_ptr
Definition grouped_convolution_forward_kernel.hpp:343
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition grouped_convolution_forward_kernel.hpp:332
AGridDescMK AGridDescMK_t
Definition grouped_convolution_forward_kernel.hpp:371
const void * wei_ptr
Definition grouped_convolution_forward_kernel.hpp:342
BGridDescNK b_grid_desc_n_k
Definition grouped_convolution_forward_kernel.hpp:348
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeBDescriptor_N_K< typename GroupedConvTraitsType_::WeiLayout >())> BGridDescNK
Definition grouped_convolution_forward_kernel.hpp:318
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeCDescriptor_M_N< typename GroupedConvTraitsType_::OutLayout >())> CGridDescMN
Definition grouped_convolution_forward_kernel.hpp:321
index_t num_spatial_pieces
Definition grouped_convolution_forward_kernel.hpp:395
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition grouped_convolution_forward_kernel.hpp:328
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition grouped_convolution_forward_kernel.hpp:327
index_t GemmN
Definition grouped_convolution_forward_kernel.hpp:337
long_index_t spatial_offset_in
Definition grouped_convolution_forward_kernel.hpp:363
SplitImageInfo split_image
Definition grouped_convolution_forward_kernel.hpp:396
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs< CDElementwise > &args)
Definition grouped_convolution_forward_kernel.hpp:45
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition grouped_convolution_forward_kernel.hpp:333
index_t output_batch_stride
Definition grouped_convolution_forward_kernel.hpp:360
long_index_t group_stride_a
Definition grouped_convolution_forward_kernel.hpp:351
index_t GemmK
Definition grouped_convolution_forward_kernel.hpp:338
void * out_ptr
Definition grouped_convolution_forward_kernel.hpp:345
ConvToGemmFwdTransformer transformer_
Definition grouped_convolution_forward_kernel.hpp:367
index_t GemmBatch
Definition grouped_convolution_forward_kernel.hpp:339
long_index_t spatial_offset_out
Definition grouped_convolution_forward_kernel.hpp:364
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition grouped_convolution_forward_kernel.hpp:326
static constexpr index_t NumDTensor
Definition grouped_convolution_forward_kernel.hpp:35
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition grouped_convolution_forward_kernel.hpp:331
index_t k_batch
Definition grouped_convolution_forward_kernel.hpp:335
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition grouped_convolution_forward_kernel.hpp:330
InPtr in_ptr
Definition grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition grouped_convolution_utils.hpp:40
index_t k_batch
Definition grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition grouped_convolution_utils.hpp:41
Definition grouped_convolution_forward_kernel.hpp:490
index_t h
Definition grouped_convolution_forward_kernel.hpp:491
index_t d
Definition grouped_convolution_forward_kernel.hpp:491
index_t w
Definition grouped_convolution_forward_kernel.hpp:491
The Grouped Convolution Forward kernel template.
Definition grouped_convolution_forward_kernel.hpp:442
static CK_TILE_DEVICE index_t FindPieceId(index_t block_id, const SplitImageInfo &split_info, index_t num_pieces)
Definition grouped_convolution_forward_kernel.hpp:536
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition grouped_convolution_forward_kernel.hpp:459
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_convolution_forward_kernel.hpp:448
static CK_TILE_HOST constexpr GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs< CDElementwise > &hostArgs)
Definition grouped_convolution_forward_kernel.hpp:583
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_convolution_forward_kernel.hpp:447
typename EpiloguePipeline::CDElementwise CDElementwise
Definition grouped_convolution_forward_kernel.hpp:470
static constexpr auto I1
Definition grouped_convolution_forward_kernel.hpp:478
static constexpr auto I2
Definition grouped_convolution_forward_kernel.hpp:479
static CK_TILE_DEVICE index_t FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
Definition grouped_convolution_forward_kernel.hpp:517
static CK_TILE_DEVICE void RunGemm(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *smem_ptr_0, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc, const index_t gemm_k, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_convolution_forward_kernel.hpp:860
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition grouped_convolution_forward_kernel.hpp:456
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition grouped_convolution_forward_kernel.hpp:589
static CK_TILE_DEVICE auto MakeGemmTensorViews(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc)
Definition grouped_convolution_forward_kernel.hpp:723
static constexpr auto I0
Definition grouped_convolution_forward_kernel.hpp:477
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
Definition grouped_convolution_forward_kernel.hpp:954
static constexpr bool EnableSplitImage
Definition grouped_convolution_forward_kernel.hpp:443
GroupedConvFwdKernelArgs< GroupedConvTraitsType_, CDElementwise > GroupedConvFwdKernelArgsSpecialized
Definition grouped_convolution_forward_kernel.hpp:472
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition grouped_convolution_forward_kernel.hpp:455
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition grouped_convolution_forward_kernel.hpp:468
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition grouped_convolution_forward_kernel.hpp:457
static constexpr index_t kBlockSize
Definition grouped_convolution_forward_kernel.hpp:462
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_convolution_forward_kernel.hpp:466
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition grouped_convolution_forward_kernel.hpp:451
static constexpr index_t NDimSpatial
Definition grouped_convolution_forward_kernel.hpp:444
static CK_TILE_HOST auto BlockSize()
Definition grouped_convolution_forward_kernel.hpp:577
static constexpr auto I3
Definition grouped_convolution_forward_kernel.hpp:480
static CK_TILE_HOST const std::string GetName()
Definition grouped_convolution_forward_kernel.hpp:559
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition grouped_convolution_forward_kernel.hpp:594
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition grouped_convolution_forward_kernel.hpp:465
static constexpr index_t NumDTensor
Definition grouped_convolution_forward_kernel.hpp:460
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition grouped_convolution_forward_kernel.hpp:805
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition grouped_convolution_forward_kernel.hpp:450
static constexpr bool IsSplitKSupported
Definition grouped_convolution_forward_kernel.hpp:475
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition grouped_convolution_forward_kernel.hpp:452
static CK_TILE_DEVICE void RunGemm2LDS(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc, const index_t gemm_k, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_convolution_forward_kernel.hpp:917
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition grouped_convolution_forward_kernel.hpp:764
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition grouped_convolution_forward_kernel.hpp:454
static constexpr ConvolutionSpecialization ConvSpecialization
Definition grouped_convolution_forward_kernel.hpp:445
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_convolution_forward_kernel.hpp:449
static CK_TILE_DEVICE SpatialCoords UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
Definition grouped_convolution_forward_kernel.hpp:496
static CK_TILE_HOST auto GridSize(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition grouped_convolution_forward_kernel.hpp:571
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition grouped_convolution_forward_kernel.hpp:464
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:28
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition tile/host/convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition tile/host/convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition tile/host/convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition tile/host/convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition tile/host/convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition tile/host/convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition tile/host/convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition tile/host/convolution_parameter.hpp:134
Definition type_traits.hpp:115
Definition tile/core/container/sequence.hpp:49
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145