17template <
typename FloatA,
20 typename AThreadDesc_TK0_TM0_TM1_TK1,
21 typename BThreadDesc_TK0_TN0_TN1_TK1,
22 typename CThreadDesc_TM0_TM1_TN0_TN1,
26 typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
27 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
28 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
34 static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
35 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
36 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
37 "wrong! Desc should be known at compile-time");
43 static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
47 template <
typename ABuffer,
53 __device__
static void Run(
const ABuffer& a_buf,
63 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
69 "wrong! inconsistent type");
74 constexpr auto TK = TKLengths{}[I0];
75 constexpr auto TM0 = TMLengths{}[I0];
76 constexpr auto TM1 = TMLengths{}[I1];
77 constexpr auto TN0 = TNLengths{}[I0];
78 constexpr auto TN1 = TNLengths{}[I1];
90 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
93 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
96 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
116template <
typename FloatA,
119 typename AThreadDesc_TK0_TM0_TM1_TK1,
120 typename BThreadDesc_TK0_TN0_TN1_TK1,
121 typename CThreadDesc_TM0_TM1_TN0_TN1,
125 typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
126 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
127 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
133 static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
134 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
135 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
136 "wrong! Desc should be known at compile-time");
142 static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
146 template <
typename ABuffer,
152 __device__
static void Run(
const ABuffer& a_buf,
154 const BBuffer& b_buf,
162 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
168 "wrong! inconsistent type");
173 constexpr index_t TK0 = TKLengths{}[I0];
174 constexpr index_t TK1 = TKLengths{}[I1];
175 constexpr index_t TM0 = TMLengths{}[I0];
176 constexpr index_t TM1 = TMLengths{}[I1];
177 constexpr index_t TN0 = TNLengths{}[I0];
178 constexpr index_t TN1 = TNLengths{}[I1];
194 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
198 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
209 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
213 a_vec.template AsType<a_vector_t>()[I0],
214 b_vec.template AsType<b_vector_t>()[I0],
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__device__ void inner_product(const TA &a, const TB &b, TC &c)
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
Definition threadwise_contraction_dl.hpp:131
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition threadwise_contraction_dl.hpp:152
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition threadwise_contraction_dl.hpp:53
__device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
Definition threadwise_contraction_dl.hpp:32
Definition is_known_at_compile_time.hpp:14
Definition functional2.hpp:33
Definition dtype_vector.hpp:10