26__host__ __device__
constexpr bool HasSlice(T&&)
28 return is_detected<is_slice, T>::value;
30template <
typename... Ts>
31__host__ __device__
constexpr bool HasSlice(Tuple<Ts...>&&)
33 return (HasSlice(Ts{}) || ...);
43template <
typename... Ts,
typename SlicedShape>
44__host__ __device__
constexpr auto GetSlicedShape(
const Tuple<Ts...>& idxs,
45 const SlicedShape&
shape)
50 constexpr auto num_i = Number<i>{};
53 if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
67 const auto& dim = size(
shape.At(num_i));
68 const auto val = idxs.At(num_i).range(dim);
77 Number<Tuple<Ts...>::Size()>{});
79 return UnrollNestedTuple<0, 1>(new_shape);
89template <
typename T,
typename Shape>
90__host__ __device__
constexpr auto GenerateMultipleFreeze(T idx,
const Shape&
shape)
96 const auto dim = unrolled_shape.At(Number<i>{});
97 const auto dim_idx = idx % dim;
101 Number<
decltype(unrolled_shape)::Size()>{});
111template <
typename... Ts,
typename Shape>
112__host__ __device__
constexpr auto GenerateSliceTransforms(
const Tuple<Ts...>& idx,
118 constexpr auto num_i = Number<i>{};
121 return GenerateSliceTransforms(idx.At(num_i),
shape.At(num_i));
126 const auto from = idx.At(num_i).from_;
127 const auto dim = size<num_i>(
shape);
128 const auto range = idx.At(num_i).range(dim);
134 return GenerateMultipleFreeze(idx.At(num_i),
shape.At(num_i));
137 Number<Tuple<Ts...>::Size()>{});
142template <index_t i,
typename LowerIndex>
143__host__ __device__
constexpr auto GetSequenceVal(
const ck::Freeze<LowerIndex>&)
149template <index_t i,
typename LowLength,
typename SliceBegin,
typename SliceEnd>
150__host__ __device__
constexpr auto GetSequenceVal(
const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
152 return Sequence<i>{};
156__host__ __device__
constexpr auto GenerateUpperDims(
const Tuple<>&)
161template <
index_t i,
typename... Transforms>
162__host__ __device__
constexpr auto GenerateUpperDims(
const Tuple<Transforms...>& transforms)
164 constexpr auto num_transforms = Tuple<Transforms...>::Size();
166 const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
167 if constexpr(
is_same_v<
decltype(current_elem),
const Sequence<>>)
169 const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
175 const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
180template <
typename... Ts,
typename Shape,
typename UnrolledDescriptor>
181__host__ __device__
constexpr auto GenerateSlicedDescriptor(
const Tuple<Ts...>& idx,
183 const UnrolledDescriptor& flatten_desc)
187 const auto transforms = GenerateSliceTransforms(idx,
shape);
188 using TransformsTupleType =
decltype(transforms);
190 const auto lower_dims =
191 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<old_shape_dims>{});
192 const auto upper_dims =
decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
210 typename ElementType,
212 typename UnrolledDescriptorType>
219 is_scalar_type<ElementType>::value,
221 typename scalar_type<std::remove_const_t<ElementType>>::type>;
224 static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
225 BufferAddressSpace == MemoryTypeEnum ::Vgpr);
232 multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
235 static_assert(
IsDynamicBuffer,
"Wrong BufferAddressSpace for register.");
240 multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
243 static_assert(!
IsDynamicBuffer,
"Wrong BufferAddressSpace for register.");
257 template <
typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
258 __host__ __device__
auto operator[](
const Tuple<Ts...>& idx)
261 const auto&
shape = layout_.GetShape();
262 auto new_shape = detail::GetSlicedShape(idx,
shape);
264 const auto& flatten_desc = layout_.GetUnrolledDescriptor();
265 auto new_desc = detail::GenerateSlicedDescriptor(idx,
shape, flatten_desc);
266 const auto new_layout =
269 base_offset_ -= new_layout(make_tuple(Number<0>{}));
273 template <
typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
274 __host__ __device__
auto operator()(
const Tuple<Ts...>& idx)
279 template <
typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}),
bool> =
false>
280 __host__ __device__
auto operator()(Idxs... idxs)
291 template <
typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
296 const index_t offset = layout_(idx) + base_offset_;
297 return buffer_[offset];
303 UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
307 UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
308 return buffer_[Number<index_offset + base_offset>{}];
312 template <
typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
318 template <
typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}),
bool> =
false>
330 template <
typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
335 const index_t offset = layout_(idx) + base_offset_;
336 return buffer_(offset);
342 UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
346 UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
347 return buffer_(Number<index_offset + base_offset>{});
351 template <
typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
357 template <
typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}),
bool> =
false>
370 return layout_.GetMergedNestingDescriptor();
380 __host__ __device__
constexpr auto&
GetBuffer() {
return buffer_; }
381 __host__ __device__
constexpr auto&
GetBuffer()
const {
return buffer_; }
395 template <
typename MultiIdxOffsets>
398 multi_idx_offset_ = multi_idx_offset;
399 base_offset_ += layout_(multi_idx_offset);
405 using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
409 using StaticBufferType = std::conditional_t<
410 is_scalar_type<ElementType>::value,
411 StaticBuffer<BufferAddressSpace,
415 StaticBufferTupleOfVector<BufferAddressSpace,
418 scalar_type<std::remove_const_t<ElementType>>::vector_size,
419 scalar_type<std::remove_const_t<ElementType>>::vector_size,
422 using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
424 const Layout<Shape, UnrolledDescriptorType> layout_;
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get the shape of a layout.
Definition layout_utils.hpp:431
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<> &element)
Definition tuple_helper.hpp:120
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
const GenericPointer< typename T::ValueType > & pointer
Definition pointer.h:1514
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
__host__ __device__ constexpr auto GetElementSpaceSize() const
Definition layout.hpp:297
__host__ __device__ constexpr const Layout< Shape, UnrolledDescriptorType > & GetLayout() const
Definition tensor.hpp:246
__host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
Apply a multi-index offset to the tensor.
Definition tensor.hpp:396
__host__ __device__ constexpr auto & GetBuffer() const
Definition tensor.hpp:381
decltype(Layout< Shape, UnrolledDescriptorType >{ Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()) ElementSpaceSize
Definition tensor.hpp:216
static constexpr bool IsDynamicBuffer
Definition tensor.hpp:224
__host__ __device__ auto operator[](const Tuple< Ts... > &idx)
Get the new sliced tensor.
Definition tensor.hpp:258
__host__ __device__ constexpr auto & GetBuffer()
Definition tensor.hpp:380
std::conditional_t< is_scalar_type< ElementType >::value, ElementType, typename scalar_type< std::remove_const_t< ElementType > >::type > TensorElementType
Definition tensor.hpp:218
__host__ __device__ constexpr auto & GetMultiIdxOffsets() const
Get the multi-index offset to the data.
Definition tensor.hpp:388
static constexpr MemoryTypeEnum TensorBufferAddressSpace
Definition tensor.hpp:223
__host__ __device__ Tensor()=delete
__host__ __device__ constexpr Tensor(const Layout< Shape, UnrolledDescriptorType > &layout)
Definition tensor.hpp:238
__host__ __device__ TensorElementType * GetPointer() const
Get pointer to the data.
Definition tensor.hpp:378
std::size_t GetElementSpaceSize() const
Definition library/utility/host_tensor.hpp:810
__host__ __device__ constexpr auto GetMergedNestingDescriptor()
Get descriptor with all nested dimensions merged.
Definition tensor.hpp:368
__host__ __device__ constexpr Tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Definition tensor.hpp:228
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get the layout of a tensor.
Definition tensor_utils.hpp:162
AddressSpaceEnum MemoryTypeEnum
Memory type, allowed members:
Definition tensor_utils.hpp:30
constexpr auto make_tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Create a tensor from a data pointer and a layout.
Definition tensor_utils.hpp:112