15template <
typename Problem_,
typename Policy_>
19 template <
typename PipelineProblem_,
typename GemmPolicy_>
29 static constexpr index_t kBlockSize = Problem::kBlockSize;
31 static constexpr index_t MPerBlock = BlockGemmShape::kM;
32 static constexpr index_t NPerBlock = BlockGemmShape::kN;
33 static constexpr index_t KPerBlock = BlockGemmShape::kK;
35 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
38 static constexpr index_t MWarp = config.template at<1>();
39 static constexpr index_t NWarp = config.template at<2>();
46 static constexpr index_t KPack = WarpGemm::kKPerThread;
53 using Traits = GemmTraits_<Problem, Policy>;
75 constexpr auto a_block_outer_dstr_encoding =
84 a_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
86 return a_block_dstr_encode;
101 a_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
103 return a_block_dstr_encode;
116 a_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
118 return a_block_dstr_encode;
127 constexpr auto b_block_outer_dstr_encoding =
135 b_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
137 return b_block_dstr_encode;
151 b_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
153 return b_block_dstr_encode;
165 b_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
166 return b_block_dstr_encode;
183 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
185 return c_block_dstr_encode;
197 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
199 return c_block_dstr_encode;
204 template <
typename CBlockTensor,
typename ABlockTensor,
typename BBlockTensor>
206 const ABlockTensor& a_block_tensor,
207 const BBlockTensor& b_block_tensor)
const
209 static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
210 std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
211 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
218 .get_static_tile_distribution_encoding())>>,
219 "A distribution is wrong!");
223 .get_static_tile_distribution_encoding())>>,
224 "B distribution is wrong!");
228 .get_static_tile_distribution_encoding())>>,
229 "C distribution is wrong!");
231 using AWarpDstr =
typename WarpGemm::AWarpDstr;
232 using BWarpDstr =
typename WarpGemm::BWarpDstr;
233 using CWarpDstr =
typename WarpGemm::CWarpDstr;
235 using AWarpTensor =
typename WarpGemm::AWarpTensor;
236 using BWarpTensor =
typename WarpGemm::BWarpTensor;
237 using CWarpTensor =
typename WarpGemm::CWarpTensor;
239 constexpr auto a_warp_y_lengths =
240 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
241 constexpr auto b_warp_y_lengths =
242 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
243 constexpr auto c_warp_y_lengths =
244 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
256 AWarpTensor a_warp_tensor;
257 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
263 BWarpTensor b_warp_tensor;
264 b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
268 CWarpTensor c_warp_tensor;
269 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
274 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
277 c_block_tensor.set_y_sliced_thread_data(
280 c_warp_tensor.get_thread_buffer());
291 AWarpTensor a_warp_tensor;
293 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
298 BWarpTensor b_warp_tensor;
300 b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
305 CWarpTensor c_warp_tensor;
307 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
312 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
315 c_block_tensor.set_y_sliced_thread_data(
318 c_warp_tensor.get_thread_buffer());
338 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
341 return c_block_tensor;
354 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
357 return c_block_tensor;
362 template <
typename ABlockTensor,
typename BBlockTensor>
364 const BBlockTensor& b_block_tensor)
const
367 operator()(c_block_tensor, a_block_tensor, b_block_tensor);
368 return c_block_tensor;
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
@ KMN
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:12
@ MNK
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:13
Definition block_gemm_areg_breg_creg_v2.hpp:17
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_gemm_areg_breg_creg_v2.hpp:60
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v2.hpp:123
static constexpr index_t NIterPerWarp
Definition block_gemm_areg_breg_creg_v2.hpp:65
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_breg_creg_v2.hpp:50
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_breg_creg_v2.hpp:325
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_gemm_areg_breg_creg_v2.hpp:59
static constexpr bool UseDefaultScheduler
Definition block_gemm_areg_breg_creg_v2.hpp:69
typename Traits::BlockGemmShape BlockGemmShape
Definition block_gemm_areg_breg_creg_v2.hpp:56
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_breg_creg_v2.hpp:51
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_gemm_areg_breg_creg_v2.hpp:61
static CK_TILE_DEVICE constexpr auto MakeCBlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v2.hpp:171
static constexpr index_t MIterPerWarp
Definition block_gemm_areg_breg_creg_v2.hpp:64
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensor &a_block_tensor, const BBlockTensor &b_block_tensor) const
Definition block_gemm_areg_breg_creg_v2.hpp:205
static constexpr index_t NWarp
Definition block_gemm_areg_breg_creg_v2.hpp:68
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v2.hpp:71
typename Traits::WarpGemm WarpGemm
Definition block_gemm_areg_breg_creg_v2.hpp:55
CK_TILE_DEVICE auto operator()(const ABlockTensor &a_block_tensor, const BBlockTensor &b_block_tensor) const
Definition block_gemm_areg_breg_creg_v2.hpp:363
GemmTraits_< Problem, Policy > Traits
Definition block_gemm_areg_breg_creg_v2.hpp:53
static constexpr index_t MWarp
Definition block_gemm_areg_breg_creg_v2.hpp:67
static constexpr index_t KIterPerWarp
Definition block_gemm_areg_breg_creg_v2.hpp:63
static constexpr auto BlockGemmLoopOrder
Definition block_gemm_areg_breg_creg_v2.hpp:57
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192