tensor_adaptor_coordinate.hpp Source File

tensor_adaptor_coordinate.hpp Source File#

Composable Kernel: tensor_adaptor_coordinate.hpp Source File
tensor_adaptor_coordinate.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
15
16namespace ck_tile {
17
// Coordinate type that stores one index per hidden dimension (idx_hidden_) and
// exposes the top and bottom indices as subsets of that hidden index.
//
// Template parameters:
//   NDimHidden               - number of hidden dimensions.
//   BottomDimensionHiddenIds - compile-time id list selecting the bottom
//                              dimensions out of the hidden index.
//   TopDimensionHiddenIds    - compile-time id list selecting the top
//                              dimensions out of the hidden index.
//
// NOTE(review): the struct declaration line, the index type aliases, the
// constructor signatures, the get_bottom_index() signature and the idx_hidden_
// member declaration are not visible in this view of the file.
18template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
20{
// Number of bottom / top dimensions, taken from the size of the id lists.
21 static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
22 static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
23
27
28 public:
30
// Constructor initializer: stores the given hidden index in idx_hidden_.
32 : idx_hidden_{idx_hidden}
33 {
34 }
35
// Returns the entries of the hidden index selected by TopDimensionHiddenIds.
36 CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
37 {
38 return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
39 }
40
// Returns the entries of the hidden index selected by BottomDimensionHiddenIds.
42 {
43 return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
44 }
45
// Read-only access to the full hidden index.
46 CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }
47
// Mutable access to the full hidden index; callers can update it in place.
48 CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }
49
50 //
52};
53
// Builds a tensor_adaptor_coordinate for `adaptor` from a top index.
//
// Parameters:
//   adaptor - supplies the transforms and the hidden-dimension id lists.
//   idx_top - top index; its size must equal the adaptor's number of top
//             dimensions (enforced by the static_assert below).
//
// Returns a tensor_adaptor_coordinate whose hidden index holds idx_top in the
// top-dimension slots and, for every transform, the lower index computed from
// that transform's upper index.
54template <typename Adaptor, typename TopIndex>
55CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
56 const TopIndex& idx_top)
57{
58 static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
59 "wrong! # of dimension inconsistent");
60
61 constexpr index_t ntransform = Adaptor::get_num_of_transform();
62 constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
63 constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
64 constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
65
// Declared without an initializer; entries are filled in by the
// set_container_subset calls below.
66 multi_index<ndim_hidden> idx_hidden;
67
68 // initialize visible index: copy idx_top into the top-dimension slots
69 set_container_subset(idx_hidden, top_dim_ids, idx_top);
70
71 // calculate hidden index
// The loop counter is the transform index plus one (itran_p1); the transform
// index itself is recovered by subtracting 1. The counter steps by -1, so
// transforms are visited from the highest index downwards.
72 static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
73 auto itran = itran_p1 - number<1>{};
74 const auto& tran = adaptor.get_transforms().at(itran);
75 constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
76 constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
77
// Upper index of this transform, read from the hidden index built so far.
78 const auto idx_up = get_container_subset(idx_hidden, dims_up);
79
80 multi_index<dims_low.size()> idx_low;
81
// The transform is expected to fill idx_low from idx_up; the result is
// written back into the hidden index right after.
82 tran.calculate_lower_index(idx_low, idx_up);
83
84 set_container_subset(idx_hidden, dims_low, idx_low);
85 });
86
87 return tensor_adaptor_coordinate<ndim_hidden,
88 remove_cvref_t<decltype(bottom_dim_ids)>,
89 remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
90}
91
// Moves `coord` by a top-index step and reports the resulting bottom-index step.
//
// Template parameters:
//   JudgeDoTransforms - when true, a transform is only processed if at least
//                       one of its upper index diff components is flagged as
//                       non-zero; when false, every transform is processed.
//
// Parameters:
//   adaptor         - supplies the transforms and hidden-dimension id lists.
//   coord           - coordinate to move; its hidden index is updated in place
//                     through coord.get_hidden_index().
//   idx_diff_top    - step to apply to the top index.
//   idx_diff_bottom - output; overwritten with the bottom-dimension entries of
//                     the computed hidden index diff.
92template <bool JudgeDoTransforms = true,
93 typename Adaptor,
94 typename AdaptorCoord,
95 typename TopIndex,
96 typename BottomIndex>
97CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
98 AdaptorCoord& coord,
99 const TopIndex& idx_diff_top,
100 BottomIndex& idx_diff_bottom)
101{
102 constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
103 constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension();
104 // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
105 constexpr index_t ntransform = Adaptor::get_num_of_transform();
106
107 // static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");
108
109 // judge whether calculation of lower diff is needed for each transform
110 // use index_t for boolean type
111 auto do_transforms = make_zero_multi_index<ntransform>();
112
113 if constexpr(JudgeDoTransforms)
114 {
// One flag per hidden dimension: non-zero means the diff for that
// dimension may be non-zero.
115 auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
116
117 // decide do_transforms by checking for non-zero index diff components
118 multi_index<ndim_top> non_zero_diff_pick_top;
119
// NOTE(review): the opening of the call that this lambda is passed to is
// not visible in this view. The lambda flags each top dimension whose diff
// component is non-zero.
121 [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
122
// NOTE(review): the opening of this call is not visible in this view; its
// arguments are the per-dimension flags, the top-dimension hidden ids and
// the top flags computed above.
124 is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);
125
// Propagate the flags through the transforms, from index ntransform - 1
// downwards.
126 static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
127 constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
128 constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
129
130 const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
131
132 multi_index<dims_low.size()> non_zero_diff_pick_low;
133
134 // if any of the upper index diff components is non-zero, then
135 // 1) this transform needs to be done
136 // 2) all components of the lower index diff are assumed to be non-zero and need to be
137 // computed
138 const bool idx_diff_up_has_non_zero = container_reduce(
139 non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
140
141 do_transforms(itran) = idx_diff_up_has_non_zero;
142
143 static_for<0, dims_low.size(), 1>{}(
144 [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
145
146 set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
147 });
148 }
149 else
150 {
// No judging requested: mark every transform as needing an update.
151 static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
152 }
153
154 // this is what needs to be calculated
155 auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
156
157 // initialize top index diff
158 set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);
159
160 // this is what needs to be updated
161 auto& idx_hidden = coord.get_hidden_index();
162
163 // update top index: read the top slots, add the step, write them back
164 auto idx_hidden_pick_top =
165 get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());
166
167 idx_hidden_pick_top += idx_diff_top;
168
169 set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);
170
171 // update rest of hidden index
172 static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
// Skipped transforms leave their lower entries of idx_diff_hidden at the
// zero they were initialized with.
173 if(do_transforms[itran])
174 {
175 const auto& tran = adaptor.get_transforms().at(itran);
176 constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
177 constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
178
// idx_up_new is read after the upper slots were already updated above;
// idx_low still holds the lower index from before this move.
179 const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
180 auto idx_low = get_container_subset(idx_hidden, dims_low);
181 const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
182
183 multi_index<dims_low.size()> idx_diff_low;
184
// The transform is expected to fill idx_diff_low and update idx_low; both
// are written back into the hidden containers below.
185 tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
186
187 set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
188 set_container_subset(idx_hidden, dims_low, idx_low);
189 }
190 });
191
192 // set bottom index diff
193 idx_diff_bottom =
194 get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
195}
196
// Convenience overload for callers that do not need the bottom index diff:
// forwards to the four-argument overload with the same JudgeDoTransforms value
// and passes a local `tmp` as the bottom-diff output argument.
//
// Parameters:
//   adaptor      - supplies the transforms and hidden-dimension id lists.
//   coord        - coordinate to move; updated in place by the callee.
//   idx_diff_top - step to apply to the top index.
197template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
198CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
199 AdaptorCoord& coord,
200 const TopIndex& idx_diff_top)
201{
202 constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
203
// NOTE(review): the declaration of `tmp` is not visible in this view;
// presumably it is a bottom index sized by ndim_bottom - confirm.
205
206 move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
207}
208
// Checks the hidden index of `coord` against every transform of `adaptor`
// that does not guarantee a valid lower index for every valid upper index.
// The top index itself is not range-checked here.
//
// Parameters:
//   adaptor - supplies the transforms and their upper-dimension hidden ids.
//   coord   - coordinate whose hidden index is inspected (read-only).
//
// Returns true when every checked transform reports its upper index as mapped
// to a valid lower index; also true when no transform needs checking.
//
// NOTE(review): the line carrying the function name and first parameter is
// not visible in this view.
209template <typename Adaptor, typename AdaptorCoord>
210CK_TILE_HOST_DEVICE constexpr bool
212 const AdaptorCoord& coord)
213{
214 bool valid = true;
215
216 constexpr index_t ntransform = Adaptor::get_num_of_transform();
217
218 const auto& idx_hidden = coord.get_hidden_index();
219
220 static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
// Held by value (not by reference), presumably so that decltype(tran) is a
// non-reference type that can be used with `::` below.
221 const auto tran = adaptor.get_transforms().at(itran);
222
223 // check validity, only if the current transformation does not always have a valid mapping
224 if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
225 {
226 const auto idx_up = get_container_subset(
227 idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));
228
229 // Comment: using valid = valid && .. will result in weird control flow in ISA
// Hence the non-short-circuiting `&=`: every checked transform is evaluated.
230 valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
231 }
232 });
233
234 return valid;
235}
236
// Full validity check for a coordinate: first range-checks every component of
// the top index against the adaptor's lengths, then combines that result with
// a further check on the hidden index.
//
// Parameters:
//   adaptor - supplies the per-dimension lengths via get_length(i).
//   coord   - coordinate to check (read-only).
//
// NOTE(review): the second operand of the final `&&` is not visible in this
// view.
237template <typename Adaptor, typename AdpatorCoord>
238CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
239 const AdpatorCoord& coord)
240{
241 // check top index
242 const auto& idx_top = coord.get_top_index();
243
244 bool is_top_index_valid = true;
245
// NOTE(review): the loop bound is Adaptor::get_num_of_dimension(), whereas
// make_tensor_adaptor_coordinate sizes the top index with
// Adaptor::get_num_of_top_dimension() - confirm the two agree for Adaptor.
246 static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
247 [&is_top_index_valid, &idx_top, &adaptor](auto i) {
// Each component must lie in the half-open range [0, adaptor.get_length(i)).
248 is_top_index_valid =
249 is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
250 });
251
252 // check other hidden index
253 return is_top_index_valid &&
255}
256
257} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition tensor_adaptor_coordinate.hpp:97
CK_TILE_HOST_DEVICE constexpr void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition tile/core/container/container_helper.hpp:420
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
Definition tile/core/container/multi_index.hpp:26
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor &adaptor, const AdaptorCoord &coord)
Definition tensor_adaptor_coordinate.hpp:211
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor &adaptor, const AdpatorCoord &coord)
Definition tensor_adaptor_coordinate.hpp:238
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition tile/core/utility/functional.hpp:43
Definition tensor_adaptor_coordinate.hpp:20
multi_index< ndim_top_ > TopIndex
Definition tensor_adaptor_coordinate.hpp:26
static constexpr index_t ndim_top_
Definition tensor_adaptor_coordinate.hpp:22
CK_TILE_HOST_DEVICE constexpr const auto & get_hidden_index() const
Definition tensor_adaptor_coordinate.hpp:46
multi_index< NDimHidden > HiddenIndex
Definition tensor_adaptor_coordinate.hpp:24
HiddenIndex idx_hidden_
Definition tensor_adaptor_coordinate.hpp:51
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate()=default
CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
Definition tensor_adaptor_coordinate.hpp:36
CK_TILE_HOST_DEVICE constexpr auto & get_hidden_index()
Definition tensor_adaptor_coordinate.hpp:48
CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
Definition tensor_adaptor_coordinate.hpp:41
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex &idx_hidden)
Definition tensor_adaptor_coordinate.hpp:31
multi_index< ndim_bottom_ > BottomIndex
Definition tensor_adaptor_coordinate.hpp:25
static constexpr index_t ndim_bottom_
Definition tensor_adaptor_coordinate.hpp:21