transform_conv_bwd_weight_to_gemm.hpp Source File

transform_conv_bwd_weight_to_gemm.hpp Source File#

Composable Kernel: transform_conv_bwd_weight_to_gemm.hpp Source File
tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp
Go to the documentation of this file.
1
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <index_t NDimSpatial,
13 index_t VectorSizeA,
14 index_t VectorSizeB,
15 index_t VectorSizeC,
16 index_t NumGroupsToMerge = 1,
17 bool SplitN = false,
18 typename ADataType = float,
19 typename CDataType = float,
20 typename IndexType = index_t>
22{
23 private:
24 static constexpr auto I0 = number<0>{};
25 static constexpr auto I1 = number<1>{};
26 static constexpr auto I2 = number<2>{};
27 static constexpr auto I3 = number<3>{};
28 static constexpr auto I4 = number<4>{};
29 static constexpr auto I5 = number<5>{};
30#if 0 // TODO: Enable these functionalities
// Returns 1 + sum over d in [i, NDimSpatial + 3) of (lengths[d] - 1) * strides[d],
// accumulated as long_index_t. Each factor is widened to long_index_t before the multiply.
// The caller chooses the first dimension `i` that takes part in the sum.
31 template <typename ConvDimsType>
32 static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
33 const ConvDimsType& strides,
34 index_t i)
35 {
36 long_index_t acc = 1;
37 for(; i < (NDimSpatial + 3); i++)
38 {
39 acc +=
40 static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
41 }
42
43 return acc;
44 }
45
// Picks the N extent to use so that the larger of the A and C element spaces (each scaled by
// sizeof(ADataType) / sizeof(CDataType)) does not exceed 2^31.
// Returns N (a_g_n_c_wis_lengths[1]) when no split is needed, N / d for the first divisor d
// found by the search below, and 1 otherwise.
// NOTE(review): relies on math::max and math::integer_divide_ceil - confirm these names are
// available in ck_tile before enabling this code.
46 template <typename ConvDimsType>
47 static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
48 const ConvDimsType& a_g_n_c_wis_strides,
49 const ConvDimsType& c_g_n_k_wos_lengths,
50 const ConvDimsType& c_g_n_k_wos_strides)
51 {
// Element spaces are measured starting from dimension 1, i.e. dimension 0 (G) is excluded.
52 const long_index_t a_element_space_size =
53 calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
54 const long_index_t c_element_space_size =
55 calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
56 const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
57 c_element_space_size * sizeof(CDataType));
58 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
59
60 const IndexType N = a_g_n_c_wis_lengths[I1];
61
62 if(element_space_size > TwoGB)
63 {
64 // Smallest factor (rounded up) by which the size must shrink to fit in 2GB
65 const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
66
67 if(divisor <= static_cast<double>(N))
68 {
69 // Find the least divisor of N that is >= divisor.
70 // NOTE(review): the search stops once least_divisor^2 > N, so divisors of N above sqrt(N) are never tried; the fallback below then returns 1.
71 for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
72 least_divisor++)
73 {
74 if(N % least_divisor == 0)
75 {
76 return N / least_divisor;
77 }
78 }
79 // Not found, process one Convolution N per block
80 return 1;
81 }
82 else
83 {
84 // Split Convolution's N dimension into N workgroups. However
85 // this still might not result in a sufficiently small tensor,
86 // but at least later on we could divide the image as well.
87 return 1;
88 }
89 }
90 else
91 {
92 // Split N is not needed.
93 return N;
94 }
95 }
96#endif
97
98 public:
100
101 template <typename TransformConvBwdWeightToGemmBase>
103 const TransformConvBwdWeightToGemmBase& transform_conv_fwd_to_gemm_base)
104 : G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
105 N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
106 Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
107 Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
108 Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
109 Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
110 Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
111 Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
112 Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
113 Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
114 X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
115 K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
116 C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
117 ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
118 ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
119 ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
120 ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
121 ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
122 ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
123 InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
124 InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
125 InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
126 InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
127 InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
128 InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)}
129 {
130 }
131
// Constructor enabled only when NDimSpatial == 1.
// Reads G from a_g_n_c_wis_lengths[0], N and K from c_g_n_k_wos_lengths[1] / [2], C and X from
// b_g_k_c_xs_lengths[2] / [3], and the W extents from index 3 of the input / output lengths.
// Every D and H quantity is fixed: extents, filter sizes, strides and dilations to 1, pads to 0.
132 template <typename ConvDimsType,
133 typename ConvSpatialDimsType,
134 index_t NDim = NDimSpatial,
135 typename std::enable_if<NDim == 1, bool>::type = false>
136 CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
137 const ConvDimsType& b_g_k_c_xs_lengths,
138 const ConvDimsType& c_g_n_k_wos_lengths,
139 const ConvSpatialDimsType& conv_filter_strides,
140 const ConvSpatialDimsType& conv_filter_dilations,
141 const ConvSpatialDimsType& input_left_pads,
142 const ConvSpatialDimsType& input_right_pads)
143 : G_{a_g_n_c_wis_lengths[I0]},
144 Di_{I1},
145 Hi_{I1},
146 Wi_{a_g_n_c_wis_lengths[I3]},
147 Do_{I1},
148 Ho_{I1},
149 Wo_{c_g_n_k_wos_lengths[I3]},
150 Z_{I1},
151 Y_{I1},
152 X_{b_g_k_c_xs_lengths[I3]},
153 K_{c_g_n_k_wos_lengths[I2]},
154 C_{b_g_k_c_xs_lengths[I2]},
155 ConvStrideD_{I1},
156 ConvStrideH_{I1},
157 ConvStrideW_{conv_filter_strides[I0]},
158 ConvDilationD_{I1},
159 ConvDilationH_{I1},
160 ConvDilationW_{conv_filter_dilations[I0]},
161 InLeftPadD_{I0},
162 InLeftPadH_{I0},
163 InLeftPadW_{input_left_pads[I0]},
164 InRightPadD_{I0},
165 InRightPadH_{I0},
166 InRightPadW_{input_right_pads[I0]}
167 {
// Arguments must be std::array or ck_tile::array of IndexType: NDimSpatial entries for the
// spatial parameters and NDimSpatial + 3 entries for the length arrays.
168 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
169 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
170 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
171 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
// NOTE(review): the disabled SplitN branch passes a_g_n_c_wis_strides / c_g_n_k_wos_strides,
// which are not parameters of this constructor.
172#if 0 // TODO: Enable these functionalities
173 if constexpr(SplitN)
174 {
175 N_ = GetSplitedNSize(
176 a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
177 }
178 else
179 {
180 N_ = c_g_n_k_wos_lengths[I1];
181 }
182#endif
// With the SplitN path disabled, N always comes from the output lengths.
183 N_ = c_g_n_k_wos_lengths[I1];
184 }
185
// Constructor enabled only when NDimSpatial == 2.
// Reads G from a_g_n_c_wis_lengths[0], N and K from c_g_n_k_wos_lengths[1] / [2], C from
// b_g_k_c_xs_lengths[2], and the (H, W) extents / filter sizes from indices 3 and 4 of the
// respective length arrays. Strides, dilations and pads use index 0 for H and index 1 for W.
// Every D quantity is fixed: extent, filter size, stride and dilation to 1, pads to 0.
186 template <typename ConvDimsType,
187 typename ConvSpatialDimsType,
188 index_t NDim = NDimSpatial,
189 typename std::enable_if<NDim == 2, bool>::type = false>
190 CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
191 const ConvDimsType& b_g_k_c_xs_lengths,
192 const ConvDimsType& c_g_n_k_wos_lengths,
193 const ConvSpatialDimsType& conv_filter_strides,
194 const ConvSpatialDimsType& conv_filter_dilations,
195 const ConvSpatialDimsType& input_left_pads,
196 const ConvSpatialDimsType& input_right_pads)
197 : G_{a_g_n_c_wis_lengths[I0]},
198 Di_{I1},
199 Hi_{a_g_n_c_wis_lengths[I3]},
200 Wi_{a_g_n_c_wis_lengths[I4]},
201 Do_{I1},
202 Ho_{c_g_n_k_wos_lengths[I3]},
203 Wo_{c_g_n_k_wos_lengths[I4]},
204 Z_{I1},
205 Y_{b_g_k_c_xs_lengths[I3]},
206 X_{b_g_k_c_xs_lengths[I4]},
207 K_{c_g_n_k_wos_lengths[I2]},
208 C_{b_g_k_c_xs_lengths[I2]},
209 ConvStrideD_{I1},
210 ConvStrideH_{conv_filter_strides[I0]},
211 ConvStrideW_{conv_filter_strides[I1]},
212 ConvDilationD_{I1},
213 ConvDilationH_{conv_filter_dilations[I0]},
214 ConvDilationW_{conv_filter_dilations[I1]},
215 InLeftPadD_{I0},
216 InLeftPadH_{input_left_pads[I0]},
217 InLeftPadW_{input_left_pads[I1]},
218 InRightPadD_{I0},
219 InRightPadH_{input_right_pads[I0]},
220 InRightPadW_{input_right_pads[I1]}
221 {
// Arguments must be std::array or ck_tile::array of IndexType: NDimSpatial entries for the
// spatial parameters and NDimSpatial + 3 entries for the length arrays.
222 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
223 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
224 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
225 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
// NOTE(review): the disabled SplitN branch passes a_g_n_c_wis_strides / c_g_n_k_wos_strides,
// which are not parameters of this constructor.
226#if 0 // TODO: Enable these functionalities
227 if constexpr(SplitN)
228 {
229 N_ = GetSplitedNSize(
230 a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
231 }
232 else
233 {
234 N_ = c_g_n_k_wos_lengths[I1];
235 }
236#endif
// With the SplitN path disabled, N always comes from the output lengths.
237 N_ = c_g_n_k_wos_lengths[I1];
238 }
239
// Constructor enabled only when NDimSpatial == 3.
// Reads G from a_g_n_c_wis_lengths[0], N and K from c_g_n_k_wos_lengths[1] / [2], C from
// b_g_k_c_xs_lengths[2], and the (D, H, W) extents / filter sizes from indices 3, 4 and 5 of
// the respective length arrays. Strides, dilations and pads use indices 0, 1, 2 for D, H, W.
240 template <typename ConvDimsType,
241 typename ConvSpatialDimsType,
242 index_t NDim = NDimSpatial,
243 typename std::enable_if<NDim == 3, bool>::type = false>
244 CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
245 const ConvDimsType& b_g_k_c_xs_lengths,
246 const ConvDimsType& c_g_n_k_wos_lengths,
247 const ConvSpatialDimsType& conv_filter_strides,
248 const ConvSpatialDimsType& conv_filter_dilations,
249 const ConvSpatialDimsType& input_left_pads,
250 const ConvSpatialDimsType& input_right_pads)
251 : G_{a_g_n_c_wis_lengths[I0]},
252 Di_{a_g_n_c_wis_lengths[I3]},
253 Hi_{a_g_n_c_wis_lengths[I4]},
254 Wi_{a_g_n_c_wis_lengths[I5]},
255 Do_{c_g_n_k_wos_lengths[I3]},
256 Ho_{c_g_n_k_wos_lengths[I4]},
257 Wo_{c_g_n_k_wos_lengths[I5]},
258 Z_{b_g_k_c_xs_lengths[I3]},
259 Y_{b_g_k_c_xs_lengths[I4]},
260 X_{b_g_k_c_xs_lengths[I5]},
261 K_{c_g_n_k_wos_lengths[I2]},
262 C_{b_g_k_c_xs_lengths[I2]},
263 ConvStrideD_{conv_filter_strides[I0]},
264 ConvStrideH_{conv_filter_strides[I1]},
265 ConvStrideW_{conv_filter_strides[I2]},
266 ConvDilationD_{conv_filter_dilations[I0]},
267 ConvDilationH_{conv_filter_dilations[I1]},
268 ConvDilationW_{conv_filter_dilations[I2]},
269 InLeftPadD_{input_left_pads[I0]},
270 InLeftPadH_{input_left_pads[I1]},
271 InLeftPadW_{input_left_pads[I2]},
272 InRightPadD_{input_right_pads[I0]},
273 InRightPadH_{input_right_pads[I1]},
274 InRightPadW_{input_right_pads[I2]}
275 {
// Arguments must be std::array or ck_tile::array of IndexType: NDimSpatial entries for the
// spatial parameters and NDimSpatial + 3 entries for the length arrays.
276 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
277 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
278 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
279 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
// NOTE(review): the disabled SplitN branch passes a_g_n_c_wis_strides / c_g_n_k_wos_strides,
// which are not parameters of this constructor.
280#if 0 // TODO: Enable these functionalities
281 if constexpr(SplitN)
282 {
283 N_ = GetSplitedNSize(
284 a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
285 }
286 else
287 {
288 N_ = c_g_n_k_wos_lengths[I1];
289 }
290#endif
// With the SplitN path disabled, N always comes from the output lengths.
291 N_ = c_g_n_k_wos_lengths[I1];
292 }
293
294#if 0 // TODO: Enable these functionalities
// Returns true when both the input (A) and output (C) element spaces, multiplied by
// sizeof(ADataType) / sizeof(CDataType), are at most 2^31 (the comparison is <=).
// NOTE(review): reads stride members (NStrideTensorA_, DiStride_, HiStride_, WiStride_,
// CStrideTensorA_ and their C-tensor counterparts) that are not among the member declarations
// visible in this listing - confirm they exist before enabling this code.
// NOTE(review): each (extent - 1) * stride product is evaluated in the operands' own type
// before being stored as long_index_t; this may overflow if those are 32-bit - confirm.
295 __host__ bool AreDescriptorsSmallerThan2GB() const
296 {
297 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
298
299 const long_index_t in_desc_space_size =
300 I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
301 (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
302 const long_index_t out_desc_space_size =
303 I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
304 (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
305
306 bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
307 bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
308
309 return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
310 }
311
// Splits this problem into a "left" and a "right" copy along one output spatial dimension.
// D is tried first, then H, then W; only the first dimension that qualifies is split.
// The left copy receives Do_/2 (resp. Ho_/2, Wo_/2) output positions and keeps the left pad;
// the right copy receives the remainder and keeps the right pad.
// Returns a tuple: (left transformer, right transformer, a_grid_ptr_base + input offset of the
// right part, c_grid_ptr_base + output offset of the right part).
// If no dimension qualifies, both copies equal *this and both offsets are 0.
// NOTE(review): the offsets are held in IndexType and use DiStride_ / DoStride_ etc., which
// are not among the member declarations visible in this listing - confirm width and existence
// before enabling this code.
312 __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
313 CDataType* c_grid_ptr_base) const
314 {
315 // Create copies
316 auto conv_to_gemm_transformer_left = *this;
317 auto conv_to_gemm_transformer_right = *this;
318 IndexType a_right_offset = 0;
319 IndexType c_right_offset = 0;
320 // Calculate the effective (dilated) filter size
321 const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
322 const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
323 const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
324 // Calculate start position in input for right tensor
325 const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
326 const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
327 const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
328 // Calculate last position in input for left tensor
329 const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
330 const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
331 const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
332 // Allow the split only if the whole left padding falls in the left tensor and the whole
333 // right padding falls in the right tensor
334 const bool is_possible_to_split_d = Do_ != 1 &&
335 di_right_transformer_start_idx > InLeftPadD_ &&
336 di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
337 const bool is_possible_to_split_h = Ho_ != 1 &&
338 hi_right_transformer_start_idx > InLeftPadH_ &&
339 hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
340 const bool is_possible_to_split_w = Wo_ != 1 &&
341 wi_right_transformer_start_idx > InLeftPadW_ &&
342 wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
343
344 if(is_possible_to_split_d)
345 {
346 // Apply new sizes
347 // Split output in half
348 conv_to_gemm_transformer_left.Do_ = Do_ / 2;
349 conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
350 // Assign left padding to left convolution
351 conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
352 conv_to_gemm_transformer_right.InLeftPadD_ = 0;
353 // Assign right padding to right convolution
354 conv_to_gemm_transformer_left.InRightPadD_ = 0;
355 conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
356 // Calculate new input size
357 conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
358 conv_to_gemm_transformer_right.Di_ =
359 math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
360 (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
361 ;
362 // Calculate offsets
363 a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
364 c_right_offset = (Do_ / 2) * DoStride_;
365 }
366 else if(is_possible_to_split_h)
367 {
// Same steps as the D split above, applied to H.
368 conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
369 conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
370
371 conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
372 conv_to_gemm_transformer_right.InLeftPadH_ = 0;
373
374 conv_to_gemm_transformer_left.InRightPadH_ = 0;
375 conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
376
377 conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
378 conv_to_gemm_transformer_right.Hi_ =
379 math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
380 (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
381 a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
382 c_right_offset = (Ho_ / 2) * HoStride_;
383 }
384 else if(is_possible_to_split_w)
385 {
// Same steps as the D split above, applied to W.
386 conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
387 conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
388
389 conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
390 conv_to_gemm_transformer_right.InLeftPadW_ = 0;
391
392 conv_to_gemm_transformer_left.InRightPadW_ = 0;
393 conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
394
395 conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
396 conv_to_gemm_transformer_right.Wi_ =
397 math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
398 (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
399
400 a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
401 c_right_offset = (Wo_ / 2) * WoStride_;
402 }
403 // Return left transformer, right transformer, right offset into Input and right offset
404 // into Output
405 return ck_tile::make_tuple(conv_to_gemm_transformer_left,
406 conv_to_gemm_transformer_right,
407 a_grid_ptr_base + a_right_offset,
408 c_grid_ptr_base + c_right_offset);
409 }
410#endif
411
412 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
414 {
415 // NWGK
416 const index_t NDoHoWoStride = G_ * K_;
417 constexpr auto KStride = I1;
418
419 if constexpr(NumGroupsToMerge > 1)
420 {
421 const index_t BatchStride = K_;
422 return make_naive_tensor_descriptor(make_tuple(K_, NumGroupsToMerge, N_ * Wo_),
423 make_tuple(KStride, BatchStride, NDoHoWoStride),
425 I1);
426 }
427 else
428 {
430 make_tuple(KStride, NDoHoWoStride),
432 I1);
433 }
434 }
435
436 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
438 {
439 // NWGC
440 const index_t NStride = Wi_ * G_ * C_;
441 const index_t WiStride = G_ * C_;
442 constexpr auto CStride = I1;
443
444 if constexpr(NumGroupsToMerge > 1)
445 {
446 const auto BatchStride = C_;
447 return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_),
448 make_tuple(NStride, WiStride, BatchStride, CStride),
450 I1);
451 }
452 else
453 {
454
456 make_tuple(NStride, WiStride, CStride),
458 I1);
459 }
460 }
461
462 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
464 {
465 // GKXC
466 const index_t KStride = X_ * C_;
467 constexpr auto CXStride = I1;
468
469 if constexpr(NumGroupsToMerge > 1)
470 {
471 const index_t XStride = C_;
472 const index_t BatchStride = K_ * X_ * C_;
473 // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
474 // for Batch+N dimension
475 const auto desc = make_naive_tensor_descriptor(
476 make_tuple(NumGroupsToMerge, K_, X_, 1, C_),
477 make_tuple(BatchStride, KStride, XStride, BatchStride, CXStride),
479 I1);
480 // Pad 1 to NumGroupsToMerge
481 const auto padded_desc = transform_tensor_descriptor(
482 desc,
483 make_tuple(make_pass_through_transform(NumGroupsToMerge),
486 make_pad_transform(1, 0, NumGroupsToMerge - 1),
492 // We only need the matrices on the diagonal. XOR returns 0 for equal
493 // values, so a matrix that is not on the diagonal is stored in the padding.
494 // To avoid a modulo after the XOR, NumGroupsToMerge is assumed to be a power of 2.
495 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
496 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
497 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
498 const auto unmerged_padded_desc = transform_tensor_descriptor(
499 padded_desc,
500 make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
506 // Merge To M, N
508 unmerged_padded_desc,
509 make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
510 make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_))),
513 }
514 else
515 {
517 make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
518 }
519 }
520
521 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
523 {
524 // NHWGK
525 const index_t NDoHoWoStride = G_ * K_;
526 constexpr auto KStride = I1;
527
528 if constexpr(NumGroupsToMerge > 1)
529 {
530 const index_t BatchStride = K_;
532 make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_), // K_Gm_M
533 make_tuple(NDoHoWoStride, BatchStride, KStride),
535 I1);
536 }
537 else
538 {
540 make_tuple(NDoHoWoStride, KStride),
542 I1);
543 }
544 }
545
546 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
548 {
549 // NHWGC
550 const index_t NStride = Hi_ * Wi_ * G_ * C_;
551 const index_t HiStride = Wi_ * G_ * C_;
552 const index_t WiStride = G_ * C_;
553 constexpr auto CStride = I1;
554
555 if constexpr(NumGroupsToMerge > 1)
556 {
557 const auto BatchStride = C_;
559 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
560 make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
562 I1);
563 }
564 else
565 {
567 make_tuple(NStride, HiStride, WiStride, CStride),
569 I1);
570 }
571 }
572
573 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
575 {
576 // GKYXC
577 const index_t KStride = Y_ * X_ * C_;
578 constexpr auto CStride = I1;
579
580 if constexpr(NumGroupsToMerge > 1)
581 {
582 const index_t YXStride = C_;
583 const index_t BatchStride = K_ * Y_ * X_ * C_;
584 // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
585 // for Batch+N dimension
586 const auto desc = make_naive_tensor_descriptor(
587 make_tuple(NumGroupsToMerge, K_, Y_ * X_, 1, C_),
588 make_tuple(BatchStride, KStride, YXStride, BatchStride, CStride),
590 I1);
591 // Pad 1 to NumGroupsToMerge
592 const auto padded_desc = transform_tensor_descriptor(
593 desc,
594 make_tuple(make_pass_through_transform(NumGroupsToMerge),
597 make_pad_transform(1, 0, NumGroupsToMerge - 1),
603 // We only need the matrices on the diagonal. XOR returns 0 for equal
604 // values, so a matrix that is not on the diagonal is stored in the padding.
605 // To avoid a modulo after the XOR, NumGroupsToMerge is assumed to be a power of 2.
606 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
607 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
608 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
609 const auto unmerged_padded_desc = transform_tensor_descriptor(
610 padded_desc,
611 make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
617 // Merge To M, N
619 unmerged_padded_desc,
620 make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
621 make_merge_transform(make_tuple(Y_ * X_, NumGroupsToMerge, C_))),
624 }
625 else
626 {
628 make_tuple(KStride, CStride),
630 I1);
631 }
632 }
633
634 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
636 {
637 // NDHWGK
638 const index_t NDoHoWoStride = G_ * K_;
639 constexpr auto KStride = I1;
640
641 if constexpr(NumGroupsToMerge > 1)
642 {
643 const auto BatchStride = K_;
645 make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge, K_),
646 make_tuple(NDoHoWoStride, BatchStride, KStride),
648 I1);
649 }
650 else
651 {
653 make_tuple(NDoHoWoStride, KStride),
655 I1);
656 }
657 }
658
659 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
661 {
662 const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_;
663 const index_t DiStride = Hi_ * Wi_ * G_ * C_;
664 const index_t HiStride = Wi_ * G_ * C_;
665 const index_t WiStride = G_ * C_;
666 constexpr auto CStride = I1;
667
668 if constexpr(NumGroupsToMerge > 1)
669 {
670 const index_t BatchStride = C_;
672 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
673 make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
675 I1);
676 }
677 else
678 {
680 make_tuple(N_, Di_, Hi_, Wi_, C_),
681 make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
683 I1);
684 }
685 }
686
687 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
689 {
690 // GKZYXC
691 const index_t KStride = Z_ * Y_ * X_ * C_;
692 constexpr auto CStride = I1;
693
694 if constexpr(NumGroupsToMerge > 1)
695 {
696 const index_t ZYXStride = C_;
697 const index_t BatchStride = K_ * Z_ * Y_ * X_ * C_;
698 // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
699 // for Batch+N dimension
700 const auto desc = make_naive_tensor_descriptor(
701 make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
702 make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride),
704 I1);
705 // Pad 1 to NumGroupsToMerge
706 const auto padded_desc = transform_tensor_descriptor(
707 desc,
708 make_tuple(make_pass_through_transform(NumGroupsToMerge),
711 make_pad_transform(1, 0, NumGroupsToMerge - 1),
717 // We only need the matrices on the diagonal. XOR returns 0 for equal
718 // values, so a matrix that is not on the diagonal is stored in the padding.
719 // To avoid a modulo after the XOR, NumGroupsToMerge is assumed to be a power of 2.
720 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
721 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
722 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
723 const auto unmerged_padded_desc = transform_tensor_descriptor(
724 padded_desc,
725 make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
731 // Merge To M, N
733 unmerged_padded_desc,
734 make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
735 make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMerge, C_))),
738 }
739 else
740 {
742 make_tuple(KStride, CStride),
744 I1);
745 }
746 }
747
748 // TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimensions as
749 // properties
750
751 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
753 {
754 const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
755 const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
756 const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
757
758 // B: input tensor comes in K_N
759 if constexpr(NumGroupsToMerge > 1)
760 {
761 // Output tensor transformation
762 // [0, 1, 2] -> [0, 1]
763 // [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
764 const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
765 out_grid_desc,
767 make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
770
771 // Input tensor transformation, part 1.
772 // [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
773 const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
774 in_grid_desc,
777 make_pass_through_transform(NumGroupsToMerge),
781
782 // Input tensor transformation, part 2.
783 // [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
784 const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
785 in_n_wip_gm_c_grid_desc,
789 make_pass_through_transform(NumGroupsToMerge),
793
794 // Input tensor transformation, part 3.
795 // [0, 1, 2, 3, 4] -> [0, 1]
796 // [N, X, Wo, Gm, C] -> [(N*Wo), (X*Gm*C)]
797 const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
798 in_n_x_wo_gm_c_grid_desc,
799 make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
803
804 return make_tuple(
805 out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
806 }
807 else
808 {
809 // [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
810 const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
811 in_grid_desc,
817
818 // [N, Wip, C] -> [N, X, Wo, C]
819 const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
820 in_n_wip_c_grid_desc,
827
828 const auto in_gemmn_gemmktotal_grid_desc =
829 transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
834
835 return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
836 }
837 }
838
839 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
841 {
842 const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
843 const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
844 const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
845
846 // B: input tensor comes in K_N
847 if constexpr(NumGroupsToMerge > 1)
848 {
849 // Output tensor transformation
850 // [0, 1, 2] -> [0, 1]
851 // [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (Gm*K)]
852 const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
853 out_grid_desc,
855 make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
858
859 // Input tensor transformation, part 1.
860 // [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
861 const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
862 in_grid_desc,
866 make_pass_through_transform(NumGroupsToMerge),
872
873 // Input tensor transformation, part 2.
874 // [N, Hip, Wip, Gm, C] -> [N, (Y, Ho), (X, Wo), Gm, C]
875 const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
876 in_n_hip_wip_gm_c_grid_desc,
882 make_pass_through_transform(NumGroupsToMerge),
889 sequence<5>{},
890 sequence<6>{}));
891
892 // Input tensor transformation, part 3.
893 // [0, 1, 2, 3, 4, 5, 6] -> [0, 1]
894 // [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Y*X*Gm*C)]
895 const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
896 in_n_y_ho_x_wo_gm_c_grid_desc,
897 make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
901
902 return make_tuple(
903 out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
904 }
905 else
906 {
907 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
908 in_grid_desc,
915
916 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
917 in_n_hip_wip_c_grid_desc,
926
927 const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
928 in_n_y_ho_x_wo_c_grid_desc,
933
934 return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
935 }
936 }
937
938 template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
940 {
941 const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
942 const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
943 const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
944
945 // B: input tensor comes in K_N
946 if constexpr(NumGroupsToMerge > 1)
947 {
948 // Output tensor transformation
949 // [0, 1, 2] -> [0, 1]
950 // [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (Gm*K)]
951 const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
952 out_grid_desc,
954 make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
957
958 // Input tensor transformation, part 1.
959 // [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
960 const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
961 in_grid_desc,
966 make_pass_through_transform(NumGroupsToMerge),
969 sequence<1>{},
970 sequence<2>{},
971 sequence<3>{},
972 sequence<4>{},
973 sequence<5>{}),
975 sequence<1>{},
976 sequence<2>{},
977 sequence<3>{},
978 sequence<4>{},
979 sequence<5>{}));
980
981 // Input tensor transformation, part 2.
982 // [N, Dip, Hip, Wip, Gm, C] -> [N, (Z, Do), (Y, Ho), (X, Wo), Gm, C]
983 const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
984 in_n_dip_hip_wip_gm_c_grid_desc,
992 make_pass_through_transform(NumGroupsToMerge),
995 sequence<1>{},
996 sequence<2>{},
997 sequence<3>{},
998 sequence<4>{},
999 sequence<5>{}),
1004 sequence<7>{},
1005 sequence<8>{}));
1006
1007 // Input tensor transformation, part 3.
1008 // [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
1009 // [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
1010 const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
1011 in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
1012 make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
1016
1017 return make_tuple(
1018 out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
1019 }
1020 else
1021 {
1022 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
1023 in_grid_desc,
1029 make_tuple(
1031 make_tuple(
1033
1034 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
1035 in_n_hip_wip_c_grid_desc,
1044 make_tuple(
1050 sequence<7>{}));
1051
1052 const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
1053 in_n_y_ho_x_wo_c_grid_desc,
1058
1059 return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
1060 }
1061 }
1062
1063 IndexType G_, N_;
1064 IndexType Di_, Hi_, Wi_;
1065 IndexType Do_, Ho_, Wo_;
1066 IndexType Z_, Y_, X_;
1067 IndexType K_, C_;
1072};
1073
1074} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
CK_TILE_HOST auto make_in_grid_desc() const
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:437
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:752
CK_TILE_HOST constexpr TransformConvBwdWeightToGemm()
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:99
CK_TILE_HOST TransformConvBwdWeightToGemm(const TransformConvBwdWeightToGemmBase &transform_conv_fwd_to_gemm_base)
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:102
CK_TILE_HOST auto make_out_grid_desc() const
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:413
CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType &a_g_n_c_wis_lengths, const ConvDimsType &b_g_k_c_xs_lengths, const ConvDimsType &c_g_n_k_wos_lengths, const ConvSpatialDimsType &conv_filter_strides, const ConvSpatialDimsType &conv_filter_dilations, const ConvSpatialDimsType &input_left_pads, const ConvSpatialDimsType &input_right_pads)
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:136
CK_TILE_HOST auto make_wei_grid_desc() const
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:463
Definition tile/core/container/sequence.hpp:49