reference_grouped_conv_bwd_weight.hpp Source File

reference_grouped_conv_bwd_weight.hpp Source File

Composable Kernel: reference_grouped_conv_bwd_weight.hpp Source File
reference_grouped_conv_bwd_weight.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <cstdlib>
#include <stdexcept>
#include <thread>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

12namespace ck_tile {
13
14template <ck_tile::index_t NDimSpatial,
15 typename InDataType,
16 typename WeiDataType,
17 typename OutDataType>
18CK_TILE_HOST void
21 const HostTensor<OutDataType>& output,
22 std::vector<ck_tile::long_index_t> conv_strides,
23 std::vector<ck_tile::long_index_t> conv_dilations,
24 std::vector<ck_tile::long_index_t> in_left_pads,
25 std::vector<ck_tile::long_index_t>)
26{
27 if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
28 weight.get_num_of_dimension() == NDimSpatial + 3 &&
29 output.get_num_of_dimension() == NDimSpatial + 3))
30 {
31 throw std::runtime_error("wrong! inconsistent dimension");
32 }
33
34 if constexpr(NDimSpatial == 1)
35 {
36 auto func = [&](auto g, auto k, auto c, auto x) {
37 float v_acc = 0;
38
39 for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
40 {
41 for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
42 {
43 auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
44 static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
45 static_cast<ck_tile::long_index_t>(in_left_pads[0]);
46
47 if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
48 {
49 InDataType v_in = input(g, n, c, wi);
50 OutDataType v_out = output(g, n, k, wo);
51 v_acc += ck_tile::type_convert<float>(v_out) *
53 }
54 }
55 }
56 OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
57 weight(g, k, c, x) = v_acc_converted;
58 };
59
61 weight.get_lengths()[0],
62 weight.get_lengths()[1],
63 weight.get_lengths()[2],
64 weight.get_lengths()[3])(std::thread::hardware_concurrency());
65 }
66 else if constexpr(NDimSpatial == 2)
67 {
68 auto func = [&](auto g, auto k, auto c, auto y, auto x) {
69 float v_acc = 0;
70
71 for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
72 {
73 for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
74 {
75 auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
76 static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
77 static_cast<ck_tile::long_index_t>(in_left_pads[0]);
78
79 for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
80 {
81 auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
82 static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
83 static_cast<ck_tile::long_index_t>(in_left_pads[1]);
84
85 if(hi >= 0 &&
87 wi >= 0 &&
89 {
90 InDataType v_in = input(g, n, c, hi, wi);
91 OutDataType v_out = output(g, n, k, ho, wo);
92
93 v_acc += ck_tile::type_convert<float>(v_out) *
95 }
96 }
97 }
98 }
99 WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
100 weight(g, k, c, y, x) = v_acc_converted;
101 };
102
104 weight.get_lengths()[0],
105 weight.get_lengths()[1],
106 weight.get_lengths()[2],
107 weight.get_lengths()[3],
108 weight.get_lengths()[4])(std::thread::hardware_concurrency());
109 }
110 else if constexpr(NDimSpatial == 3)
111 {
112 auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
113 float v_acc = 0;
114
115 for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
116 {
117 for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
118 {
119 auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
120 static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
121 static_cast<ck_tile::long_index_t>(in_left_pads[0]);
122 for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
123 {
124 auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
125 static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
126 static_cast<ck_tile::long_index_t>(in_left_pads[1]);
127 for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
128 {
129 auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
130 static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
131 static_cast<ck_tile::long_index_t>(in_left_pads[2]);
132 if(di >= 0 &&
134 hi >= 0 &&
136 wi >= 0 &&
138 {
139 InDataType v_in = input(g, n, c, di, hi, wi);
140 OutDataType v_out = output(g, n, k, do_, ho, wo);
141
142 v_acc += ck_tile::type_convert<float>(v_out) *
144 }
145 }
146 }
147 }
148 }
149 WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
150 weight(g, k, c, z, y, x) = v_acc_converted;
151 };
152
154 weight.get_lengths()[0],
155 weight.get_lengths()[1],
156 weight.get_lengths()[2],
157 weight.get_lengths()[3],
158 weight.get_lengths()[4],
159 weight.get_lengths()[5])(std::thread::hardware_concurrency());
160 }
161 else
162 {
163 throw std::runtime_error(
164 "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
165 }
166}
167} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
int64_t long_index_t
Definition integer.hpp:11
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST void reference_grouped_conv_bwd_weight(const HostTensor< InDataType > &input, HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition reference_grouped_conv_bwd_weight.hpp:19
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396