warp_gemm_smfmac_impl.hpp Source File

warp_gemm_smfmac_impl.hpp Source File#

Composable Kernel: warp_gemm_smfmac_impl.hpp Source File
warp_gemm_smfmac_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7namespace ck_tile {
8
9template <typename WarpGemmAttribute_>
11{
13
14 static constexpr index_t kM = WarpGemmAttribute::kM;
15 static constexpr index_t kN = WarpGemmAttribute::kN;
16 static constexpr index_t kK = WarpGemmAttribute::kK;
21 static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
22
23 using ADataType = typename WarpGemmAttribute::ADataType;
24 using BDataType = typename WarpGemmAttribute::BDataType;
25 using CDataType = typename WarpGemmAttribute::CDataType;
26
27 using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
28 using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
29 using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
30
34
38
39 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
40 {
41 return WarpGemmAttribute_::get_num_of_access();
42 }
43
44 //----------------------------------------------------------------------------------------------
52 template <typename AVec>
54 {
55 int32_t idx = 0b11101110;
56
57 static_for<0, 2, 1>{}([&](auto i) {
58 ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
59 int32_t non_zero_pos = 0;
60
61 static_for<0, 3, 1>{}([&](auto j) {
62 if(a_vec[i * 4 + j] != 0.0f)
63 {
64 nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
65 idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
66 idx |= j << 2 * (i * 2 + non_zero_pos);
67 ++non_zero_pos;
68 }
69 });
70 a_vec[i * 2] = nonzero_elems[0];
71 a_vec[i * 2 + 1] = nonzero_elems[1];
72 });
73
74 return idx;
75 }
76
77 template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
79 operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
80 {
84 constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
85
86 using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
87 using AVecCompressed =
88 ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
89 using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
90 using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
91
92 constexpr auto I0 = number<0>{};
93
94 auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
95 const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
96 auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
97
98 const int32_t idx = compress_a(a_vec);
99
100 // @TODO can we simply set a_vec_pruned to a_vec[0:3]?
101 const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
102
103 // c_vec += a_vec * b_vec[idx]
104 WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant<post_nop_>{});
105
106 c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
107 }
108};
109
110} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
constexpr bool is_similiar_distributed_tensor_v
Definition static_distributed_tensor.hpp:230
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t int32_t
Definition integer.hpp:10
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
signed int int32_t
Definition stdint.h:123
Definition warp_gemm_smfmac_impl.hpp:11
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, bool_constant< post_nop_ >={}) const
Definition warp_gemm_smfmac_impl.hpp:79
typename WarpGemmAttribute::BWarpDstrEncoding BWarpDstrEncoding
Definition warp_gemm_smfmac_impl.hpp:28
static_distributed_tensor< CDataType, CWarpDstr > CWarpTensor
Definition warp_gemm_smfmac_impl.hpp:37
static_distributed_tensor< BDataType, BWarpDstr > BWarpTensor
Definition warp_gemm_smfmac_impl.hpp:36
remove_cvref_t< decltype(make_static_tile_distribution(BWarpDstrEncoding{}))> BWarpDstr
Definition warp_gemm_smfmac_impl.hpp:32
remove_cvref_t< decltype(make_static_tile_distribution(CWarpDstrEncoding{}))> CWarpDstr
Definition warp_gemm_smfmac_impl.hpp:33
static_distributed_tensor< ADataType, AWarpDstr > AWarpTensor
Definition warp_gemm_smfmac_impl.hpp:35
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition warp_gemm_smfmac_impl.hpp:29
CK_TILE_DEVICE int32_t compress_a(AVec &a_vec) const
Compress A vector for 2:4 structured sparsity instruction by moving all non-zero elements into the lower half of a_vec.
Definition warp_gemm_smfmac_impl.hpp:53
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_smfmac_impl.hpp:39
remove_cvref_t< decltype(make_static_tile_distribution(AWarpDstrEncoding{}))> AWarpDstr
Definition warp_gemm_smfmac_impl.hpp:31
typename WarpGemmAttribute::AWarpDstrEncoding AWarpDstrEncoding
Definition warp_gemm_smfmac_impl.hpp:27
remove_cvref_t< WarpGemmAttributeSmfmac< WarpGemmAttributeSmfmacImplF16F16F32M32N32K16< WGAttrCtlEnum::Default_ > > > WarpGemmAttribute
Definition warp_gemm_smfmac_impl.hpp:12
Definition static_distributed_tensor.hpp:21
Definition tile/core/utility/functional.hpp:43