warp_gemm_attribute_smfmac_impl.hpp Source File

warp_gemm_attribute_smfmac_impl.hpp Source File

Composable Kernel: warp_gemm_attribute_smfmac_impl.hpp Source File
warp_gemm_attribute_smfmac_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
10// fp16 2:4 structured sparsity
11
// Attribute struct wrapping the __builtin_amdgcn_smfmac_f32_32x32x16_f16 builtin
// (fp16 inputs, float accumulation, 2:4 structured sparsity).
//
// NOTE(review): this listing is an extracted Doxygen page and is missing the
// `struct` declaration line plus several aliases. Per the cross-reference index
// of the same page the missing members are:
//   ADataType = fp16_t, BDataType = fp16_t, IdxDataType = int32_t,
//   AVecType = ext_vector_t<fp16_t, 4>, BVecType = ext_vector_t<fp16_t, 8>,
//   CVecType = ext_vector_t<float, 16>.
//
// Template parameter Ctrl_ is re-exported unchanged as the static member Ctrl.
12template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
14{
15 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
19 using CDataType = float;
20
24
// Tile extents; they match the 32x32x16 suffix of the builtin called below.
25 static constexpr index_t kM = 32;
26 static constexpr index_t kN = 32;
27 static constexpr index_t kK = 16;
28
29 static constexpr index_t kAMBlock = 1;
30 static constexpr index_t kBNBlock = 1;
31
// Note: kABKLane * kABKPerLane == kK (2 * 8 == 16), and
// kAMLane * kABKLane == 64 (presumably the lane count of one wavefront - TODO confirm).
// kABKPerLane (8) equals the BVecType length; kABKPerLane / CompressionRatio (4)
// equals the AVecType length.
32 static constexpr index_t kAMLane = 32;
33 static constexpr index_t kBNLane = 32;
34 static constexpr index_t kABKLane = 2;
35 static constexpr index_t kABKPerLane = 8;
36
// Note: kCMLane * kCM0PerLane * kCM1PerLane == kM (2 * 4 * 4 == 32), and
// kCM0PerLane * kCM1PerLane == 16, the CVecType length.
37 static constexpr index_t kCMLane = 2;
38 static constexpr index_t kCNLane = 32;
39 static constexpr index_t kCM0PerLane = 4;
40 static constexpr index_t kCM1PerLane = 4;
41
// Ratio between the B and A per-lane vector lengths (8 vs 4 fp16 elements).
42 static constexpr index_t CompressionRatio = 2;
43
// Accumulates into c_vec in place.
//   c_vec : accumulator, read and overwritten with the builtin's result
//   a_vec : A operand vector (4 x fp16 per the index above)
//   b_vec : B operand vector (8 x fp16 per the index above)
//   idx   : index operand forwarded as the builtin's 4th argument
//   bool_constant<post_nop_> : accepted but not used in the visible body
// The two trailing builtin arguments are hard-coded to 0.
// When the guard below is false, every argument is only assigned to
// ck_tile::ignore and c_vec is not written.
44 // c_vec += a_vec * b_vec[idx]
45 template <bool post_nop_ = false>
47 const AVecType& a_vec,
48 const BVecType& b_vec,
49 const int32_t& idx,
50 bool_constant<post_nop_> = {}) const
51 {
// NOTE(review): the guard tests `__gfx94_` / `__gfx95_` with a single trailing
// underscore. Architecture macros conventionally end in two underscores
// (e.g. `__gfx94__`) - confirm these exact names are defined somewhere;
// if they are not, the builtin branch is never compiled and this call is a no-op.
52#if defined(__gfx94_) or defined(__gfx95_)
53 c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0);
54#else
55 ck_tile::ignore = c_vec;
56 ck_tile::ignore = a_vec;
57 ck_tile::ignore = b_vec;
58 ck_tile::ignore = idx;
59#endif
60 }
61};
62
// Attribute struct wrapping the __builtin_amdgcn_smfmac_f32_16x16x32_f16 builtin
// (fp16 inputs, float accumulation).
//
// NOTE(review): this listing is an extracted Doxygen page and is missing the
// `struct` declaration line plus several aliases. Per the cross-reference index
// of the same page the missing members are:
//   ADataType = fp16_t, BDataType = fp16_t, IdxDataType = int32_t,
//   AVecType = ext_vector_t<fp16_t, 4>, BVecType = ext_vector_t<fp16_t, 8>,
//   CVecType = ext_vector_t<float, 4>.
//
// Template parameter Ctrl_ is re-exported unchanged as the static member Ctrl.
63template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
65{
66 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
70 using CDataType = float;
71
75
// Tile extents; they match the 16x16x32 suffix of the builtin called below.
76 static constexpr index_t kM = 16;
77 static constexpr index_t kN = 16;
78 static constexpr index_t kK = 32;
79
80 static constexpr index_t kAMBlock = 1;
81 static constexpr index_t kBNBlock = 1;
82
// Note: kABKLane * kABKPerLane == kK (4 * 8 == 32), and
// kAMLane * kABKLane == 64 (presumably the lane count of one wavefront - TODO confirm).
// kABKPerLane (8) equals the BVecType length; kABKPerLane / CompressionRatio (4)
// equals the AVecType length.
83 static constexpr index_t kAMLane = 16;
84 static constexpr index_t kBNLane = 16;
85 static constexpr index_t kABKLane = 4;
86 static constexpr index_t kABKPerLane = 8;
87
// Note: kCMLane * kCM0PerLane * kCM1PerLane == kM (4 * 1 * 4 == 16), and
// kCM0PerLane * kCM1PerLane == 4, the CVecType length.
88 static constexpr index_t kCMLane = 4;
89 static constexpr index_t kCNLane = 16;
90 static constexpr index_t kCM0PerLane = 1;
91 static constexpr index_t kCM1PerLane = 4;
92
// Ratio between the B and A per-lane vector lengths (8 vs 4 fp16 elements).
93 static constexpr index_t CompressionRatio = 2;
94
// Accumulates into c_vec in place.
//   c_vec : accumulator, read and overwritten with the builtin's result
//   a_vec : A operand vector (4 x fp16 per the index above)
//   b_vec : B operand vector (8 x fp16 per the index above)
//   idx   : index operand forwarded as the builtin's 4th argument
//   bool_constant<post_nop_> : accepted but not used in the visible body
// The two trailing builtin arguments are hard-coded to 0.
// When the guard below is false, every argument is only assigned to
// ck_tile::ignore and c_vec is not written.
95 // c_vec += a_vec * b_vec[idx]
96 template <bool post_nop_ = false>
98 const AVecType& a_vec,
99 const BVecType& b_vec,
100 const int32_t& idx,
101 bool_constant<post_nop_> = {}) const
102 {
// NOTE(review): the guard tests `__gfx94_` / `__gfx95_` with a single trailing
// underscore. Architecture macros conventionally end in two underscores
// (e.g. `__gfx94__`) - confirm these exact names are defined somewhere;
// if they are not, the builtin branch is never compiled and this call is a no-op.
103#if defined(__gfx94_) or defined(__gfx95_)
104 c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0);
105#else
106 ck_tile::ignore = c_vec;
107 ck_tile::ignore = a_vec;
108 ck_tile::ignore = b_vec;
109 ck_tile::ignore = idx;
110#endif
111 }
112};
113
114} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition warp_gemm_attribute_mfma_impl.hpp:15
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
_Float16 fp16_t
Definition half.hpp:110
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
int32_t int32_t
Definition integer.hpp:10
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
int32_t index_t
Definition integer.hpp:9
Definition warp_gemm_attribute_smfmac_impl.hpp:65
ext_vector_t< fp16_t, 8 > BVecType
Definition warp_gemm_attribute_smfmac_impl.hpp:73
int32_t IdxDataType
Definition warp_gemm_attribute_smfmac_impl.hpp:69
static constexpr index_t kM
Definition warp_gemm_attribute_smfmac_impl.hpp:76
static constexpr index_t kAMLane
Definition warp_gemm_attribute_smfmac_impl.hpp:83
float CDataType
Definition warp_gemm_attribute_smfmac_impl.hpp:70
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_smfmac_impl.hpp:86
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_smfmac_impl.hpp:90
static constexpr index_t CompressionRatio
Definition warp_gemm_attribute_smfmac_impl.hpp:93
fp16_t ADataType
Definition warp_gemm_attribute_smfmac_impl.hpp:67
ext_vector_t< fp16_t, 4 > AVecType
Definition warp_gemm_attribute_smfmac_impl.hpp:72
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, const int32_t &idx, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_smfmac_impl.hpp:97
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_smfmac_impl.hpp:91
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_smfmac_impl.hpp:66
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_smfmac_impl.hpp:80
static constexpr index_t kN
Definition warp_gemm_attribute_smfmac_impl.hpp:77
static constexpr index_t kCMLane
Definition warp_gemm_attribute_smfmac_impl.hpp:88
fp16_t BDataType
Definition warp_gemm_attribute_smfmac_impl.hpp:68
static constexpr index_t kCNLane
Definition warp_gemm_attribute_smfmac_impl.hpp:89
static constexpr index_t kBNLane
Definition warp_gemm_attribute_smfmac_impl.hpp:84
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_smfmac_impl.hpp:74
static constexpr index_t kK
Definition warp_gemm_attribute_smfmac_impl.hpp:78
static constexpr index_t kABKLane
Definition warp_gemm_attribute_smfmac_impl.hpp:85
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_smfmac_impl.hpp:81
Definition warp_gemm_attribute_smfmac_impl.hpp:14
static constexpr index_t kN
Definition warp_gemm_attribute_smfmac_impl.hpp:26
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_smfmac_impl.hpp:35
static constexpr index_t kK
Definition warp_gemm_attribute_smfmac_impl.hpp:27
ext_vector_t< fp16_t, 4 > AVecType
Definition warp_gemm_attribute_smfmac_impl.hpp:21
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_smfmac_impl.hpp:40
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, const int32_t &idx, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_smfmac_impl.hpp:46
static constexpr index_t kABKLane
Definition warp_gemm_attribute_smfmac_impl.hpp:34
static constexpr index_t kM
Definition warp_gemm_attribute_smfmac_impl.hpp:25
ext_vector_t< float, 16 > CVecType
Definition warp_gemm_attribute_smfmac_impl.hpp:23
float CDataType
Definition warp_gemm_attribute_smfmac_impl.hpp:19
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_smfmac_impl.hpp:39
fp16_t BDataType
Definition warp_gemm_attribute_smfmac_impl.hpp:17
static constexpr index_t kAMLane
Definition warp_gemm_attribute_smfmac_impl.hpp:32
static constexpr index_t kCNLane
Definition warp_gemm_attribute_smfmac_impl.hpp:38
static constexpr index_t kCMLane
Definition warp_gemm_attribute_smfmac_impl.hpp:37
static constexpr index_t kBNLane
Definition warp_gemm_attribute_smfmac_impl.hpp:33
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_smfmac_impl.hpp:30
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_smfmac_impl.hpp:29
int32_t IdxDataType
Definition warp_gemm_attribute_smfmac_impl.hpp:18
ext_vector_t< fp16_t, 8 > BVecType
Definition warp_gemm_attribute_smfmac_impl.hpp:22
fp16_t ADataType
Definition warp_gemm_attribute_smfmac_impl.hpp:16
static constexpr index_t CompressionRatio
Definition warp_gemm_attribute_smfmac_impl.hpp:42
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_smfmac_impl.hpp:15