threadwise_tensor_slice_set.hpp Source File

threadwise_tensor_slice_set.hpp Source File#

Composable Kernel: threadwise_tensor_slice_set.hpp Source File
threadwise_tensor_slice_set.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
11
12// Assume:
13// 1. Desc is known at compile-time
14// 2. Buffer is StaticBuffer
15// 3. OriginIdx is known at compile-time
16// 4. use #-step
17template <typename Data,
18 typename Desc,
19 typename SliceLengths,
20 typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
22{
23 static constexpr index_t nDim = SliceLengths::Size();
24
26
27 template <typename OriginIdx, typename Buffer>
28 __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
29 {
30 static_assert(Desc::IsKnownAtCompileTime(),
31 "wrong! SrcDesc and DstDesc need to known at compile-time");
32
33 static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
34
36 "wrong! OriginIdx need to be known at compile-time");
37
38 // Desc is known at compile-time
39 constexpr auto desc = remove_cvref_t<Desc>{};
40
41 // OriginIdx is known at compile-time
42 constexpr auto origin_idx = to_multi_index(OriginIdx{});
43
44 static_ford<SliceLengths>{}([&](auto access_idx) {
45 constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
46
47 constexpr bool is_valid =
49
50 constexpr index_t offset = coord.GetOffset();
51
52 if constexpr(is_valid)
53 {
54 buf(Number<offset>{}) = initial_value;
55 }
56 });
57 }
58};
59
60} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition threadwise_tensor_slice_set.hpp:22
__device__ void Run(const Desc &, const OriginIdx &, Buffer &buf, const Data &initial_value) const
Definition threadwise_tensor_slice_set.hpp:28
MultiIndex< nDim > Index
Definition threadwise_tensor_slice_set.hpp:25
static constexpr index_t nDim
Definition threadwise_tensor_slice_set.hpp:23
Definition is_known_at_compile_time.hpp:14
Definition functional3.hpp:97