threadwise_gemm_dlops_v3.hpp Source File

threadwise_gemm_dlops_v3.hpp Source File

Composable Kernel: threadwise_gemm_dlops_v3.hpp Source File
threadwise_gemm_dlops_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
5#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
6
7#include "common_header.hpp"
8#include "math.hpp"
9
10namespace ck {
11
12// C[M, N] += transpose(A[K, M]) * B[K, N]
13// Matrix elements may be vectorized data types.
14// Assumptions:
15// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2 and CThreadDesc_K_N_Ho_Wo are known at
16// compile-time
17// 2. AOriginIdx, BOriginIdx and COriginIdx are known at compile-time
18template <typename FloatA,
19 typename FloatB,
20 typename FloatC,
21 typename AThreadDesc_E1_K_E2,
22 typename BThreadDesc_E1_N_Ho_Wo_E2,
23 typename CThreadDesc_K_N_Ho_Wo,
24 typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
25 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
26 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
27 bool>::type = false>
29{
30
// Run: accumulates into c_buf from a_buf and b_buf using fully unrolled static_for loops.
//
// Every buffer offset is a compile-time constant obtained from CalculateOffset on a
// default-constructed thread descriptor, with the matching origin index added to the loop
// indices. The three origin-index parameters are unnamed: only their types are used (a
// default-constructed instance of each is converted with to_multi_index below).
//
// a_buf / b_buf are taken by const reference and indexed with operator[]; c_buf is taken by
// non-const reference and indexed with operator(), and its elements are passed as the trailing
// arguments of the compute calls.
//
// NOTE(review): this listing is an extracted documentation page and is missing several lines
// (the embedded line numbers jump 49->53, 55->59, 120->122 and 154->156). The heads of two
// static_asserts and the callee names of both compute calls are therefore not visible here.
31 template <typename ABuffer,
32 typename AOriginIdx,
33 typename BBuffer,
34 typename BOriginIdx,
35 typename CBuffer,
36 typename COriginIdx>
37 __device__ static void Run(const ABuffer& a_buf,
38 AOriginIdx,
39 const BBuffer& b_buf,
40 BOriginIdx,
41 CBuffer& c_buf,
42 COriginIdx)
43 {
44
// All three thread descriptors must be compile-time known.
45 static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
46 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
47 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
48 "wrong! Desc should be known at compile-time");
49
// NOTE(review): listing lines 50-52 (the opening of this static_assert and its condition) are
// missing from this extract; only the message string survives.
53 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
54
// NOTE(review): listing lines 56-58 (the condition of this static_assert) are missing from
// this extract; per the message it checks type consistency — confirm against the original.
55 static_assert(
59 "wrong! inconsistent type");
60
// Compile-time dimension indices used with GetLength below.
61 constexpr auto I0 = Number<0>{};
62 constexpr auto I1 = Number<1>{};
63 constexpr auto I2 = Number<2>{};
64 constexpr auto I3 = Number<3>{};
65
// Loop extents: E1, K, E2 are dimensions 0, 1, 2 of the A descriptor.
66 constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
67 constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1);
68 constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
69
// Ho, Wo are dimensions 2 and 3 of the B descriptor. Dimension 1 (N) is never iterated:
// every B and C offset below uses index 0 for it.
70 constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
71 constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
72
// Origin indices are rebuilt from the parameter types alone.
73 constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
74 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
75 constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
76
// Branch 1: both Ho and Wo are even. h and w advance in steps of SubHW (2), and each
// iteration handles the four positions (h, w), (h, w+1), (h+1, w), (h+1, w+1) in one call.
77 if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
78 {
79 constexpr auto SubHW = 2;
80
// Loop nest order, outermost to innermost: k, h, w, e1, e2.
81 static_for<0, K, 1>{}([&](auto k) {
82 static_for<0, Ho, SubHW>{}([&](auto h) {
83 static_for<0, Wo, SubHW>{}([&](auto w) {
84 static_for<0, E1, 1>{}([&](auto e1) {
85 static_for<0, E2, 1>{}([&](auto e2) {
// One A element at (e1, k, e2). It is not referenced on any visible line; presumably it is
// the first argument of the call whose opening line is missing below — TODO confirm.
86 constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
87 a_origin_idx + make_tuple(e1, k, e2));
88
// Four B elements at (e1, 0, h|h+1, w|w+1, e2).
89 constexpr index_t b0_offset =
90 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
91 b_origin_idx + make_tuple(e1, 0, h, w, e2));
92
93 constexpr index_t b1_offset =
94 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
95 b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
96
97 constexpr index_t b2_offset =
98 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
99 b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
100
101 constexpr index_t b3_offset =
102 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
103 b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
104
// Four C elements at (k, 0, h|h+1, w|w+1), in the same order as the B elements above.
105 constexpr index_t c0_offset =
106 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
107 make_tuple(k, 0, h, w));
108
109 constexpr index_t c1_offset =
110 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
111 c_origin_idx + make_tuple(k, 0, h, w + 1));
112
113 constexpr index_t c2_offset =
114 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
115 c_origin_idx + make_tuple(k, 0, h + 1, w));
116
117 constexpr index_t c3_offset =
118 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
119 c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
120
// NOTE(review): listing line 121 (callee name and first argument) is missing from this
// extract. The page's cross-reference list mentions amd_assembly_outer_product_1x4(a, b0..b3,
// c0..c3), which matches the argument shape below — confirm against the original header.
122 b_buf[Number<b0_offset>{}],
123 b_buf[Number<b1_offset>{}],
124 b_buf[Number<b2_offset>{}],
125 b_buf[Number<b3_offset>{}],
126 c_buf(Number<c0_offset>{}),
127 c_buf(Number<c1_offset>{}),
128 c_buf(Number<c2_offset>{}),
129 c_buf(Number<c3_offset>{}));
130 });
131 });
132 });
133 });
134 });
135 }
// Branch 2: Ho or Wo is odd. Same loop nest, but h and w advance by 1 and each iteration
// handles a single (h, w) position.
136 else
137 {
138
139 static_for<0, K, 1>{}([&](auto k) {
140 static_for<0, Ho, 1>{}([&](auto h) {
141 static_for<0, Wo, 1>{}([&](auto w) {
142 static_for<0, E1, 1>{}([&](auto e1) {
143 static_for<0, E2, 1>{}([&](auto e2) {
// A element at (e1, k, e2); as in branch 1, its use is on a line missing from this extract.
144 constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
145 a_origin_idx + make_tuple(e1, k, e2));
146
// B element at (e1, 0, h, w, e2).
147 constexpr index_t b_offset =
148 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
149 b_origin_idx + make_tuple(e1, 0, h, w, e2));
150
// C element at (k, 0, h, w).
151 constexpr index_t c_offset =
152 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
153 make_tuple(k, 0, h, w));
154
// NOTE(review): listing line 155 (callee name and first argument) is missing from this
// extract. The page's cross-reference list mentions inner_product(a, b, c), which matches
// the argument shape below — confirm against the original header.
156 b_buf[Number<b_offset>{}],
157 c_buf(Number<c_offset>{}));
158 });
159 });
160 });
161 });
162 });
163 }
164 }
165};
166
167} // namespace ck
168#endif
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition amd_inline_asm.hpp:106
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void inner_product(const TA &a, const TB &b, TC &c)
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition threadwise_gemm_dlops_v3.hpp:29
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition threadwise_gemm_dlops_v3.hpp:37
Definition is_known_at_compile_time.hpp:14
Definition type.hpp:177
Definition functional2.hpp:33