block_reduce.hpp Source File

block_reduce.hpp Source File

Composable Kernel: block_reduce.hpp Source File
block_reduce.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7#include <tuple>
8
9// This file does not support cross-warp reduction
10namespace ck_tile {
11
12/*
13 * TODO: block_tile_reduce_sync() currently has a limitation:
14 * at least one Y dim must remain un-reduced.
15 */
16// Synchronize a reduce result across lanes: cross-lane reduction over the replicated (R) dims
// owned by the lane P dim, optionally followed by a broadcast on those replicated dims.
//
// acc_tensor    : distributed tensor whose thread buffer holds per-thread partial results;
//                 every thread-buffer element is overwritten in place with the synced value.
// reduce_func   : binary op, invoked as reduce_func(a, b) on two thread-buffer values.
// WithBroadcast : when true, run the backward broadcast sweep after the reduction. That sweep
//                 treats only lanes whose R index is 0 as already holding the full result, so
//                 without it the other lanes presumably keep partial values.
// CrossWarp     : selects how the forward sweep pulls remote data. true uses
//                 warp_shuffle_down(v_local, lid_delta); false reduces the two values returned
//                 by warp_shuffle_down_pair(v_local).
//
// Every R length handled here must be a power of two (static_assert below).
17template <typename AccDistributedTensor_,
18 typename ReduceFunc,
19 bool WithBroadcast = true,
20 bool CrossWarp = true>
21CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
22 const ReduceFunc& reduce_func,
25{
26 using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
27 using DstrEncode = typename Dstr::DstrEncode;
28 using DstrEncodeDetail = typename DstrEncode::detail;
29
30 constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
31 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
32
// the last P dimension is the one treated as the lane dimension in this function
33 constexpr index_t idim_p_lane = NDimP - 1;
34
// this thread's R indices; only read by the broadcast sweep to decide which lanes already
// hold the reduced value
35 const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
36 const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
37
38 constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
39
40 // loop over thread data
// each thread-buffer element i is synced independently of the others
42 auto v_local = acc_tensor.get_thread_buffer()[i];
43
44 // cross-lane reduce for replication
45 // only reduce on R dimension correspond to lane
46 // (lane id maps to this R dimension)
47 static_for<0, NDimR, 1>{}([&](auto idim_r) {
48 // FIXME: nasty to use does_p_own_r_
49 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
50 {
51 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
52
// presumably the lane-id step between neighbouring indices of this R dim
53 constexpr index_t lid_over_rid_derivative =
54 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
55
56 static_assert(is_power_of_two_integer(r_length),
57 "wrong! only support power of 2 reduction");
58
59 constexpr index_t nstage = integer_log2_floor(r_length);
60
61 // reduction sweep forward
62 static_for<0, nstage, 1>{}([&](auto istage) {
63 if constexpr(CrossWarp)
64 {
// Stride starts at half the R extent and halves every stage (tree reduction).
// Assuming warp_shuffle_down reads lane id + lid_delta, after nstage stages the lane
// with R index 0 has combined all r_length values.
65 constexpr index_t lid_delta =
66 lid_over_rid_derivative * (1 << (nstage - istage - 1));
67
68 // pull data from remote lane
69 const auto v_remote = warp_shuffle_down(v_local, lid_delta);
70
71 // reduce
72 v_local = reduce_func(v_local, v_remote);
73 }
74 else
75 {
// NOTE(review): this branch uses neither istage nor lid_over_rid_derivative, so
// every stage performs the same pair shuffle; confirm this is intended when
// r_length > 2.
76 // pull data from remote lane
77 const auto v_swapped_regs = warp_shuffle_down_pair(v_local);
78 // reduce
79 v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
80 }
81 });
82 }
83 });
84
85 if constexpr(WithBroadcast)
86 {
87 // cross-lane broadcast for replication
88 // only broadcast on R dimension correspond to lane
89 // (lane id maps to this R dimension)
90 static_for<0, NDimR, 1>{}([&](auto idim_r) {
91 // FIXME: nasty to use does_p_own_r_
92 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
93 {
94 const index_t r_id = rs_idx[idim_r];
95
96 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
97
// NDimP - 1 is the same P dim as idim_p_lane used in the reduction sweep
98 constexpr index_t lid_over_rid_derivative =
99 DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
100
101 static_assert(is_power_of_two_integer(r_length),
102 "wrong! only support power of 2 reduction");
103
104 constexpr index_t nstage = integer_log2_floor(r_length);
105
106 // broadcast sweep backward
// Stride doubles every stage. Lanes with r_id < 2^istage already hold the result
// and keep it; the others take the value from lid_delta lanes below. After nstage
// stages, every R index in [0, r_length) holds the result.
107 static_for<0, nstage, 1>{}([&](auto istage) {
108 // do I hold reduced data?
109 const bool do_i_hold_reduced_data = r_id < (1 << istage);
110
111 constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
112
113 // pull data from remote lane
114 const auto v_remote = warp_shuffle_up(v_local, lid_delta);
115
116 // decide whether to update local data with remote data
117 v_local = do_i_hold_reduced_data ? v_local : v_remote;
118 });
119 }
120 });
121 }
122
// write the synced value back into the same thread-buffer slot
123 acc_tensor.get_thread_buffer()(i) = v_local;
124 });
125}
126
127/*
128 * This version reduces with xor shuffles, so no broadcast pass is needed afterwards
128 * (the original author notes it is faster than block_tile_reduce_sync).
129 * TODO: is the limitation that the P dim being reduced can map to only one R dim?
130 */
// Cross-lane reduction over the replicated (R) dims owned by the lane P dim.
// At stage istage, every lane reads v_local from lane
// __lane_id() ^ (lid_over_rid_derivative << istage) and reduces. After log2(r_length) stages,
// every participating lane holds the combined value.
//
// acc_tensor  : every thread-buffer element is overwritten in place with the synced value.
// reduce_func : invoked as reduce_func(own_value, partner_value).
//
// NOTE(review): partner lanes compute reduce_func(a, b) and reduce_func(b, a) respectively,
// so they only agree if reduce_func is commutative.
// NOTE(review): the partner lane is derived from the hardware __lane_id() rather than from
// the tile's partition index; this assumes the lane P dim maps directly onto it.
131template <typename AccDistributedTensor_, typename ReduceFunc>
132CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
133 const ReduceFunc& reduce_func)
134{
135 using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
136 using DstrEncode = typename Dstr::DstrEncode;
137 using DstrEncodeDetail = typename DstrEncode::detail;
138
139 constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
140 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
141
// the last P dimension is the one treated as the lane dimension in this function
142 constexpr index_t idim_p_lane = NDimP - 1;
143
144 constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
145
146 // loop over thread data
// each thread-buffer element i is synced independently of the others
148 auto v_local = acc_tensor.get_thread_buffer()[i];
149
150 // cross-lane reduce for replication
151 // only reduce on R dimension correspond to lane
152 // (lane id maps to this R dimension)
153 static_for<0, NDimR, 1>{}([&](auto idim_r) {
154 // FIXME: nasty to use does_p_own_r_
155 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
156 {
157 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
158
// presumably the lane-id step between neighbouring indices of this R dim
159 constexpr index_t lid_over_rid_derivative =
160 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
161
162 static_assert(is_power_of_two_integer(r_length),
163 "wrong! only support power of 2 reduction");
164
165 constexpr index_t nstage = integer_log2_floor(r_length);
166
167 // reduction sweep forward
// the xor mask doubles every stage: derivative, 2*derivative, 4*derivative, ...
168 static_for<0, nstage, 1>{}([&](auto istage) {
169 // xor
170 index_t src_lane =
171 __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
172
173 // pull data from remote lane
174 const auto v_remote = warp_shuffle(v_local, src_lane);
175
176 // reduce
177 v_local = reduce_func(v_local, v_remote);
178 });
179 }
180 });
181
// write the synced value back into the same thread-buffer slot
182 acc_tensor.get_thread_buffer()(i) = v_local;
183 });
184}
185
186// FIXME: this is for 2D to 1D reduce only, need to support n-D
// In-thread reduction of in_tensor into acc_tensor. There is no cross-lane communication here;
// combine lanes afterwards with block_tile_reduce_sync or block_tile_reduce_xor_sync.
//
// The active path treats distributed span 0 of in_tensor as the kept dimension and span 1 as
// the reduced dimension:
//   acc[i0] = reduce_func(...reduce_func(acc[i0], in[i0, j_first])..., in[i0, j_last])
//
// Accumulation starts from the current contents of acc_tensor, so acc_tensor should already
// hold the reduction's init value. reduce_func is invoked as reduce_func(acc_value, in_value)
// and must accept the accumulator type and the input element type.
//
// The sequence<InReduceDims...> argument is not read by the active path: dim 1 is always the
// one reduced.
187template <typename AccDistributedTensor_,
188 typename InDistributedTensor_,
189 index_t... InReduceDims,
190 typename ReduceFunc>
191CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
192 const InDistributedTensor_& in_tensor,
194 const ReduceFunc& reduce_func)
195{
196 constexpr auto I0 = number<0>{};
197 constexpr auto I1 = number<1>{};
198
// NOTE(review): the disabled n-D branch below would not compile if it were enabled:
// - it refers to undeclared names (ndim_free, ndim_reduce, is_free_dims_arr);
// - it lacks a ';' after `cnt++`;
// - its lambda returns is_free_dims instead of in_free_dims.
199#if 0
200 constexpr auto in_reduce_dims = sequence<InReduceDims...>{};
201
202 constexpr index_t ndim_in = InDistributedTensor_::get_num_of_dimension();
203 constexpr index_t ndim_in_reduce = in_reduce_dims.size();
204 constexpr index_t ndim_in_free = ndim_in - ndim_in_reduce;
205
206 constexpr auto in_free_dims_arr = [&] {
207 array<bool, ndim_free> is_free_dims{true};
208
209 for(index_t i = 0; i < ndim_reduce; i++)
210 {
211 is_free_dims(in_reduce_dims[i]) = false;
212 }
213
214 array<index_t, ndim_free> in_free_dims{-1};
215
216 index_t cnt = 0;
217
218 for(index_t i = 0; i < ndim_in; i++)
219 {
220 if(is_free_dims[i])
221 {
222 in_free_dims(cnt) = i;
223
224 cnt++
225 }
226 }
227
228 return is_free_dims;
229 }();
230
231 constexpr auto in_free_dims = TO_SEQUENCE(is_free_dims_arr, ndim_in_free);
232#else
233
234 constexpr auto spans = InDistributedTensor_::get_distributed_spans();
235
236 // in-thread reduction
237 // FIXME: hard coded to be 2D to 1D reduction
238 sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
// one accumulator element per index of the kept dimension (span 0)
239 constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
240
241 auto acc = acc_tensor[acc_dstr_idx];
242
243 // FIXME: inner loop folds every element of span 1 (the reduced dim) into acc
244 sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
245 constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
246
247 const auto in = in_tensor[in_dstr_idx];
248
249 acc = reduce_func(acc, in);
250 });
251
252 acc_tensor(acc_dstr_idx) = acc;
253 });
254#endif
255}
256
257/*
258 * TODO: block_tile_reduce() currently has a limitation:
259 * at least one Y dim must remain un-reduced.
260 */
// Allocating overload. It:
// 1. builds an AccDataType_ tensor whose distribution comes from in_tensor's static encoding
//    with in_reduce_dims removed (presumably via make_reduce_tile_distribution_encoding; the
//    full expression is not visible in this listing);
// 2. fills that tensor with type_convert<AccDataType>(reduce_init);
// 3. runs the in-thread reduction and returns the tensor.
//
// Only the in-thread part is done here. To combine lanes, the caller still has to run
// block_tile_reduce_sync or block_tile_reduce_xor_sync on the result.
// reduce_init must have exactly in_tensor's DataType (after removing cv/ref); otherwise the
// static_assert below fails.
261template <typename AccDataType_,
262 typename InDistributedTensor_,
263 index_t... InReduceDims,
264 typename ReduceFunc,
265 typename InDataType_>
266CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
267 sequence<InReduceDims...> in_reduce_dims,
268 const ReduceFunc& reduce_func,
269 const InDataType_& reduce_init)
270{
271 using InDataType = typename InDistributedTensor_::DataType;
272 using AccDataType = remove_cvref_t<AccDataType_>;
273
274 static_assert(std::is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
275
276 // declare acc_tensor
277 constexpr auto acc_dstr =
279 InDistributedTensor_::get_tile_distribution().get_static_tile_distribution_encoding(),

282 auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);
283
284 // init acc_tensor
285 tile_elementwise_inout([&](auto& acc) { acc = type_convert<AccDataType>(reduce_init); },
286 acc_tensor);
287
288 // in-thread reduction only (no cross-lane sync happens here)
289 block_tile_reduce(acc_tensor, in_tensor, in_reduce_dims, reduce_func);
290
291 return acc_tensor;
292}
293
294// BlockReduce2D supports only 2D -> 1D reduction: of the input dims seq<0, 1>, dim 1 is
294// reduced (ReduceDim = sequence<1> is hard coded in MakeDstBlockTile).
295// The input, accumulator and output all use the input tile's DataType (InDataType).
296// operator() performs the in-thread reduction and the cross-lane xor sync in a single call.
297//
// Holds a copy of the input tile (member `t`) and the init value (member `reduce_init`).
// Usage sketch, relying on the deduction guide after this struct:
//   auto out = BlockReduce2D{tile, init}(reduce_func);
298template <typename InDistributedTensor_>
300{
302 using InDataType = typename InDistributedTensor::DataType;
303
// constructor: stores the input tile and the reduction init value
305 : t(t_), reduce_init(reduce_init_)
306 {
307 }
308
310 {
// Builds the output tile: InDataType elements, with a distribution derived from the input
// encoding with ReduceDim removed. Every element is set to reduce_init.
311 using ReduceDim = sequence<1>; // hard coded
312 constexpr auto acc_dstr =
314 InDistributedTensor::get_tile_distribution()
315 .get_static_tile_distribution_encoding(),
316 ReduceDim{}));
317
318 auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);
319 // init acc_tensor
320 tile_elementwise_inout([&](auto& x_) { x_ = type_convert<InDataType>(reduce_init); }, dst_);
321 return dst_;
322 }
323
324 // intended to return the number of pixels each lane needs to reduce (see NOTE in body)
326 {
// NOTE(review): this body has no return statement, so the deduced `auto` return type is void
// and `spans` is unused. It does not currently return a length.
327 constexpr auto spans = InDistributedTensor::get_distributed_spans();
328 }
329
330 // ReducePacksPerXDim does not have the same meaning here as in static_uford/sweep_tile_uspan:
331 // it is the number of packs along the X dim, from which the unpacks along the Y dims are
332 // computed internally.
333 // For simplicity only the row dimension is supported: ReducePacksPerXDim always has 2
334 // elements and the first one is ignored. Also for simplicity, the Y dim to split is always
335 // searched from right to left.
//
// reduce_func      : in-thread combine. It is invoked as reduce_func(acc, x...), with one input
//                    element per unpacked index passed by sweep_tile_uspan.
// reduce_sync_func : binary op passed to block_tile_reduce_xor_sync for the cross-lane step.
// Returns the reduced tile (built by MakeDstBlockTile).
336 template <typename ReduceFunc,
337 typename ReduceSyncFunc,
338 typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
339 CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
340 const ReduceSyncFunc& reduce_sync_func,
341 ReducePacksPerXDim = {}) const
342 {
343 constexpr auto spans = InDistributedTensor::get_distributed_spans();
344
// How to split span 1 (the row / reduced dim):
// - row_y_size is the product of its Y lengths;
// - it is cut into ReducePacksPerXDim[1] slices of row_y_slice_size elements each;
// - the second result of slice_sequence is used as the unpack counts for
//   sweep_tile_uspan (semantics per sequence.hpp; confirm there).
345 constexpr auto row_y_unpacks = [&]() {
346 constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
347 constexpr auto row_y_size =
348 reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
349 constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
350
351 static_assert(row_y_size % row_y_packs == 0);
352
353 constexpr auto row_y_slice_size = row_y_size / row_y_packs;
354
355 constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
356 constexpr auto unpacks = slice_info[number<1>{}];
357 return unpacks;
358 }();
359
360 auto acc_tensor = MakeDstBlockTile();
361
362 // in-thread reduction
363 // FIXME: hard coded to be 2D to 1D reduction
364 sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
365 constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
366
367 auto acc = acc_tensor[acc_dstr_idx];
368
370 spans[number<1>{}],
371 [&](auto... dstr_idx_i1) {
372 acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
373 },
374 row_y_unpacks);
375
376 acc_tensor(acc_dstr_idx) = acc;
377 });
378
379 // TODO: always use xor to do cross-lane reduce
380 block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
381
382 return acc_tensor;
383 }
384
// convenience overload: uses the same function for the in-thread and cross-lane steps
385 template <typename ReduceFunc>
386 CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
387 {
388 return operator()(reduce_func, reduce_func);
389 }
390
393};
394
395// deduction guide: BlockReduce2D{tile, init} deduces InDistributedTensor_ = T from the tile.
// The second parameter (typename T::DataType) is a non-deduced context, so the init value only
// needs to be convertible to the tile's DataType.
396template <typename T>
397CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
398
399} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE_EXTERN
Definition config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:762
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition tile_distribution.hpp:22
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:132
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
CK_TILE_DEVICE T warp_shuffle_up(const T &v_local, uint32_t lane_delta)
Definition utility.hpp:31
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F &f, Unpacks={})
Definition sweep_tile.hpp:37
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition utility.hpp:78
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE T warp_shuffle_down(const T &v_local, uint32_t lane_delta)
Definition utility.hpp:48
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T &v_local)
Definition utility.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
Definition tile/core/numeric/math.hpp:455
constexpr auto slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition tile/core/container/sequence.hpp:1249
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T &, const typename T::DataType &) -> BlockReduce2D< T >
Definition block_reduce.hpp:300
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
Definition block_reduce.hpp:309
remove_cvref_t< InDistributedTensor_ > InDistributedTensor
Definition block_reduce.hpp:301
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
Definition block_reduce.hpp:325
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc &reduce_func, const ReduceSyncFunc &reduce_sync_func, ReducePacksPerXDim={}) const
Definition block_reduce.hpp:339
InDataType reduce_init
Definition block_reduce.hpp:392
InDistributedTensor t
Definition block_reduce.hpp:391
typename InDistributedTensor::DataType InDataType
Definition block_reduce.hpp:302
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor &t_, const InDataType &reduce_init_)
Definition block_reduce.hpp:304
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc &reduce_func) const
Definition block_reduce.hpp:386
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10