bfloat16.hpp Source File

bfloat16.hpp Source File

Composable Kernel: bfloat16.hpp Source File
bfloat16.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
9#if CK_TILE_USE_LLVM_BUILTIN_BF16
10#include <hip/hip_bfloat16.h>
11#endif
12#include <stdint.h>
13
14#pragma once
15
16namespace ck_tile {
17
// bf16_rounding_mode: selects how float_to_bf16_raw() converts float -> bf16.
// NOTE(review): this rendering drops the enum header (line 18) and the enumerators on
// lines 21-23; per the index at the end of this page they are truncate_with_nan (21),
// truncate (22) and standard_asm (23), in that order.
19{
20 standard = 0, // rtn: round to nearest, ties to even (float_to_bf16_rtn_raw)
// truncate_with_nan: drop the low 16 bits but keep NaN a NaN (presumably float_to_bf16_truc_nan_raw)
// truncate: drop the low 16 bits, i.e. round toward zero (float_to_bf16_truc_raw)
// standard_asm: round to nearest, ties to even, via float_to_bf16_rtn_asm
24 rta_asm, // round to nearest away
25};
26
// Forward declarations of the float -> bf16 and double -> bf16 raw converters (presumably
// float_to_bf16_raw / double_to_bf16_raw with default rounding CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
// — TODO confirm); the remainder of each declaration is not shown in this rendering.
27template <bf16_rounding_mode rounding =
30
31template <bf16_rounding_mode rounding =
34
// bf16 bit pattern -> float / double; declared early so bfloat16_t below can use them.
36constexpr float bf16_to_float_raw(uint16_t x);
37
39constexpr double bf16_to_double_raw(uint16_t x);
40
// Storage type for bfloat16_t, chosen at compile time:
//   CK_TILE_USE_CUSTOM_DATA_TYPE  -> the wrapper struct below (raw 16-bit pattern + conversions)
//   CK_TILE_USE_LLVM_BUILTIN_BF16 -> the compiler's __bf16
//   otherwise                     -> a raw ushort bit pattern
41#if CK_TILE_USE_CUSTOM_DATA_TYPE
42// HIP uses __hip_bfloat16 as a struct
43struct alignas(2) bfloat16_t
44{
45 using raw_type = uint16_t;
46 raw_type data; // raw bf16 bit pattern: 1 sign, 8 exponent, 7 mantissa bits
47
// Reinterpret a raw bit pattern as a bfloat16_t (no numeric conversion).
49 static constexpr bfloat16_t bit_cast(raw_type x)
50 {
51 bfloat16_t y;
52 y.data = x;
53 return y;
54 }
55
56 // default constructor: all bits zero, i.e. +0.0
57 constexpr bfloat16_t() : data() {}
58
59 // construct from float (float_to_bf16_raw with its default rounding mode)
61 explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
62
63 // construct from double (double_to_bf16_raw with its default rounding mode)
65 explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
66
67 // construct from int, via float (magnitudes above 2^24 are rounded to float first)
69 explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
70
71 // construct from unsigned int, via float
73 explicit constexpr bfloat16_t(const unsigned int& x)
74 : data(float_to_bf16_raw(static_cast<float>(x)))
75 {
76 }
77
78 // cast to float (exact: every bf16 value is representable as a float)
80 explicit constexpr operator float() const { return bf16_to_float_raw(data); }
81
82 // cast to double (exact)
84 explicit constexpr operator double() const { return bf16_to_double_raw(data); }
85
86 // cast to int: truncates toward zero via float; out-of-range values, inf and NaN are UB
88 explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
89
90 // internal access: mutable reference to the raw bit pattern
92 constexpr raw_type& get() { return data; }
93
// internal access: raw bit pattern by value
95 constexpr raw_type get() const { return data; }
96};
// native_t<T>::type: presumably the native scalar type backing T — TODO confirm.
97template <typename>
98struct native_t;
99
// NOTE(review): the specialization header (line 101, presumably `struct native_t<bfloat16_t>`)
// is not shown in this rendering; it maps bfloat16_t to ushort.
100template <>
102{
103 using type = ushort;
104};
105using bf16_t = bfloat16_t;
106using bf16_raw_t = typename bf16_t::raw_type;
107#else
108#if CK_TILE_USE_LLVM_BUILTIN_BF16
109using bfloat16_t = __bf16;
110#else
111using bfloat16_t = ushort;
112#endif
// NOTE(review): lines 113-114 (`bf16_t` = bfloat16_t and `bf16_raw_t` = uint16_t, per the
// index) are not shown in this rendering.
115#endif
116// round to nearest, ties to even (RTNE); Inf is kept and NaN stays NaN
// NOTE(review): this rendering omits the signature (line 118,
// `CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_rtn_raw(float f)` per the index) and
// line 120, which presumably loads `bits` with the raw 32-bit pattern of f.
119{
121 if(~bits & 0x7f800000)
122 {
123 // When the exponent bits are not all 1s, then the value is zero, normal,
124 // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
125 // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
126 // This causes the bfloat16's mantissa to be incremented by 1 if the 16
127 // least significant bits of the float mantissa are greater than 0x8000,
128 // or if they are equal to 0x8000 and the least significant bit of the
129 // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
130 // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
131 // has the value 0x7f, then incrementing it causes it to become 0x00 and
132 // the exponent is incremented by one, which is the next higher FP value
133 // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
134 // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
135 // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
136 // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
137 // incrementing it causes it to become an exponent of 0xFF and a mantissa
138 // of 0x00, which is Inf, the next higher value to the unrounded value.
139 bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
140 }
141 else if(bits & 0xffff)
142 {
143 // When all of the exponent bits are 1, the value is Inf or NaN.
144 // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
145 // mantissa bit. Quiet NaN is indicated by the most significant mantissa
146 // bit being 1. Signaling NaN is indicated by the most significant
147 // mantissa bit being 0 but some other bit(s) being 1. If any of the
148 // lower 16 bits of the mantissa are 1, we set the least significant bit
149 // of the bfloat16 mantissa, in order to preserve signaling NaN in case
150 // the bfloat16's mantissa bits are all 0.
151 bits |= 0x10000; // Preserve signaling NaN
152 }
153 return uint16_t(bits >> 16);
154}
155
157constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
158
// Round-to-nearest-even float -> bf16 using VALU inline asm. For non-NaN inputs the result
// matches float_to_bf16_rtn_raw; every NaN (either sign) becomes 0x7fff.
// NOTE(review): the signature lines (159-160) are not shown in this rendering; this is
// presumably the CK_TILE_DEVICE overload of float_to_bf16_rtn_asm.
161{
162 union
163 {
164 float fp32;
165 uint32_t int32;
166 } u = {f};
167
168 static constexpr uint32_t FP32_NAN = 0x7fff0000;
169 static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
170
// On GFX9 the compare writes a 64-bit lane mask (presumably wave64), hence an SGPR pair.
171#if defined(__GFX9__)
172 using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
173 uint32x2_t check_nan;
174#else
175 uint32_t check_nan;
176#endif
177 uint32_t tmp;
// Operands: %0 = check_nan, %1 = tmp, %2 = u.fp32 (in/out), %3 = 0x7fff bias, %4 = NaN pattern.
//   v_cmp_u_f32   : check_nan = (f is unordered, i.e. NaN)
//   v_bfe_u32     : tmp = bit 16 of f (the LSB of the bf16 mantissa)
//   v_add3_u32    : tmp = bits(f) + tmp + 0x7fff   (round to nearest, ties to even)
//   v_cndmask_b32 : u = check_nan ? 0x7fff0000 : tmp
//   v_lshrrev_b32 : u >>= 16
178 asm volatile("\n \
179 v_cmp_u_f32 %0, %2, %2 \n \
180 v_bfe_u32 %1, %2, 16, 1 \n \
181 v_add3_u32 %1, %2, %1, %3 \n \
182 v_cndmask_b32 %2, %1, %4, %0 \n \
183 v_lshrrev_b32 %2, 16, %2 \n \
184 "
185 : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
186 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
187
// the rounded bf16 pattern now sits in the low 16 bits
188 return uint16_t(u.int32);
189}
190
191// TODO: do we need this on host?
// NOTE(review): the host overload of float_to_bf16_rta_asm (line 193, CK_TILE_HOST per the
// index) and the signature of the overload below (lines 195-196) are not shown in this rendering.
194
// Round float -> bf16 to nearest with ties away from zero: add 0x8000 (0x7fff bias + 1) to
// the raw bits and keep the high 16 bits; any NaN becomes 0x7fff.
197{
198 union
199 {
200 float fp32;
201 struct
202 {
203 uint16_t lo;
204 uint16_t hi;
205 };
206 } u = {f};
// `hi` aliases bits 31..16 of fp32 only on a little-endian target (AMDGPU is little-endian).
207
208 const uint32_t low_nan = 0x7fff; // rounding bias (not a NaN pattern, despite the name)
209 const uint32_t hi_nan = 0x7fff0000; // canonical NaN, already in the high half
210
211#if defined(__GFX9__)
212 using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
213 uint32x2_t check_nan;
214#else
215 uint32_t check_nan;
216#endif
217
// v_cmp_u_f32   : check_nan = (f is NaN)
// v_add3_u32    : x = x + 0x7fff + 1
// v_cndmask_b32 : x = check_nan ? 0x7fff0000 : x
// NOTE(review): check_nan is bound "+s" (read-write) but never initialized; v_cmp writes it
// before any read, so "=s" would describe the asm more accurately.
218 asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
219 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
220 "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
221 : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
222 : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
223
224 // Note: the above code snippet leaves the bf16 result in the high 16 bits of u
225 return u.hi;
226}
227
228// Truncate instead of rounding, preserving SNaN
// Keeps the high 16 bits of f (round toward zero). If f is Inf/NaN (exponent all ones) and any
// discarded low bit is set, bit 0 is OR-ed in so a NaN whose payload sits only in the low 16
// bits does not collapse into Inf.
// NOTE(review): the signature (line 230, `float_to_bf16_truc_nan_raw(float f)` per the index)
// and line 232 (presumably `bits` = raw pattern of f) are not shown in this rendering.
231{
233 return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
234}
235
236// Fast truncate instead of rounding, RTZ
// Keeps only the high 16 bits. Unlike float_to_bf16_truc_nan_raw, a NaN whose payload lies
// entirely in the low 16 bits becomes Inf.
// NOTE(review): the signature (line 238, `float_to_bf16_truc_raw(float f)` per the index) and
// line 240 (presumably `bits` = raw pattern of f) are not shown in this rendering.
239{
241 return static_cast<uint16_t>(bits >> 16);
242}
243
// Compile-time dispatch from bf16_rounding_mode to the matching float -> bf16 converter.
// NOTE(review): the signature (line 245, `float_to_bf16_raw(float f, constant<rounding> = {})`
// per the index) and line 252 (the truncate_with_nan branch body, presumably
// `return float_to_bf16_truc_nan_raw(f);`) are not shown in this rendering.
244template <bf16_rounding_mode rounding>
246{
247 if constexpr(rounding == bf16_rounding_mode::standard)
248 return float_to_bf16_rtn_raw(f);
249 else if constexpr(rounding == bf16_rounding_mode::standard_asm)
250 return float_to_bf16_rtn_asm(f);
251 else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
253 else if constexpr(rounding == bf16_rounding_mode::rta_asm)
254 return float_to_bf16_rta_asm(f);
255 else
// remaining mode: bf16_rounding_mode::truncate
256 return float_to_bf16_truc_raw(f);
257}
258
// double -> bf16: narrows to float first, then applies float_to_bf16_raw with the same mode.
// NOTE(review): rounding twice (double -> float, then float -> bf16) can differ from rounding
// the double directly when the intermediate float lands exactly on a bf16 tie.
// NOTE(review): the signature (line 260, `double_to_bf16_raw(double f, constant<rounding> = {})`
// per the index) is not shown in this rendering.
259template <bf16_rounding_mode rounding>
261{
262 return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
263}
264
constexpr float bf16_to_float_raw(uint16_t x)
{
    // A bf16 pattern is the high half of the equivalent binary32 pattern, so
    // widening is exact: move the 16 bits to the top of a 32-bit word.
    union
    {
        uint32_t bits;
        float value;
    } pun = {static_cast<uint32_t>(x) << 16};
    return pun.value;
}

constexpr double bf16_to_double_raw(uint16_t x)
{
    // bf16 -> float and float -> double are both exact, so go through float.
    const float as_float = bf16_to_float_raw(x);
    return static_cast<double>(as_float);
}
281
// float -> bfloat16_t value.
// NOTE(review): the signature (line 284, `float_to_bf16(float f, constant<rounding> = {})` per
// the index) and the non-builtin path (line 289) are not shown in this rendering.
282template <bf16_rounding_mode rounding =
285{
286#if CK_TILE_USE_LLVM_BUILTIN_BF16
// Compiler conversion to __bf16; the `rounding` template argument is not used on this path.
287 return static_cast<bfloat16_t>(f);
288#else
290#endif
291}
292
// double -> bfloat16_t (declared at line 295 per the index; the rest is not shown here).
293template <bf16_rounding_mode rounding =
299
302
304constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
305
// fp16 -> bfloat16_t (declared at line 308 per the index as
// `fp16_to_bf16(half_t f, constant<rounding> = {})`; the rest is not shown here).
306template <bf16_rounding_mode rounding =
312
314constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
315
316template <class T>
317struct numeric;
318
319template <>
321{
322 // minimum finite value, or minimum positive normalized value for float
324 {
325 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
326 }
327
328 // minumum finite value
330 {
331 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
332 }
333
334 // maximum finite value
336 {
337 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
338 }
339
340 // difference between 1.0 and next value representable by float
342 {
343 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
344 }
345
346 // maximum rounding error
347 // maximum rounding error
348 // bin : f edcba 9876543210
349 // bits: s eeeeeeee mmmmmmm
350 // 0 01111110 0000000 (0.5)
351 //
353 {
354 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
355 }
356
357 // positive infinity value
359 {
360 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
361 }
362
363 // quiet NaN
365 {
366 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
367 }
368
369 // signaling NaN
371 {
372 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
373 }
374
375 // smallest positive subnormal value
377 {
378 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
379 }
381 {
382 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
383 }
384};
385
// Bit-layout traits for bfloat16_t.
// NOTE(review): the specialization header (line 387) is not shown in this rendering; presumably
// it specializes a traits template for bfloat16_t — TODO confirm the name.
386template <>
388{
389 static constexpr int exp = 8; // exponent bits
390 static constexpr int mant = 7; // explicit mantissa bits
391 static constexpr int PackedSize = 1; // presumably: one bf16 value per storage element
392};
393
// NOTE(review): line 395 is not shown in this rendering; presumably it invokes
// CK_TILE_ARITHMETIC_USING_FLOAT (defined in numeric.hpp) to give the custom bfloat16_t struct
// float-based arithmetic operators — TODO confirm.
394#if CK_TILE_USE_CUSTOM_DATA_TYPE
396#endif
397
398// math
// abs: clears the sign bit (bit 15) of the raw pattern. This is valid for every value,
// including Inf and NaN (the NaN payload is preserved).
// NOTE(review): the signature (line 400, `bfloat16_t abs(const bfloat16_t& x)` per the index)
// is not shown in this rendering.
401{
402 return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
403}
404
406bool isnan(const bfloat16_t& x)
407{
409 return (xx & 0x7FFF) > 0x7C00;
410}
411
414{
415 return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
416};
417
420{
421 return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
422};
423
425bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
426
428bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
429
430} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
Definition config.hpp:72
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr float bf16_to_float(bfloat16_t x)
Definition bfloat16.hpp:301
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition bfloat16.hpp:284
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_raw(float f)
Definition bfloat16.hpp:238
CK_TILE_HOST_DEVICE constexpr float bf16_to_float_raw(uint16_t x)
Definition bfloat16.hpp:266
_Float16 half_t
Definition half.hpp:111
ushort bfloat16_t
Definition bfloat16.hpp:111
bfloat16_t bf16_t
Definition bfloat16.hpp:113
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant< rounding >={})
Definition bfloat16.hpp:245
CK_TILE_HOST_DEVICE constexpr double bf16_to_double_raw(uint16_t x)
Definition bfloat16.hpp:277
_Float16 fp16_t
Definition half.hpp:110
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f)
Definition bfloat16.hpp:193
CK_TILE_HOST_DEVICE constexpr half_t bf16_to_fp16(bfloat16_t x)
Definition bfloat16.hpp:314
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
Definition bfloat16.hpp:230
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant< rounding >={})
Definition bfloat16.hpp:260
bf16_rounding_mode
Definition bfloat16.hpp:19
@ truncate
Definition bfloat16.hpp:22
@ truncate_with_nan
Definition bfloat16.hpp:21
@ standard
Definition bfloat16.hpp:20
@ rta_asm
Definition bfloat16.hpp:24
@ standard_asm
Definition bfloat16.hpp:23
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_rtn_raw(float f)
Definition bfloat16.hpp:118
@ constant
Definition arch.hpp:51
uint16_t bf16_raw_t
Definition bfloat16.hpp:114
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition bfloat16.hpp:406
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant< rounding >={})
Definition bfloat16.hpp:295
CK_TILE_HOST constexpr uint16_t float_to_bf16_rtn_asm(float f)
Definition bfloat16.hpp:157
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant< rounding >={})
Definition bfloat16.hpp:308
CK_TILE_HOST_DEVICE constexpr double bf16_to_double(bfloat16_t x)
Definition bfloat16.hpp:304
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
uint32_t uint32x2_t
Definition vector_type.hpp:163
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
Definition tile/core/numeric/integral_constant.hpp:13
Definition vector_type.hpp:26
static CK_TILE_HOST_DEVICE constexpr bfloat16_t epsilon()
Definition bfloat16.hpp:341
static CK_TILE_HOST_DEVICE constexpr bfloat16_t max()
Definition bfloat16.hpp:335
static CK_TILE_HOST_DEVICE constexpr bfloat16_t quiet_NaN()
Definition bfloat16.hpp:364
static CK_TILE_HOST_DEVICE constexpr bfloat16_t infinity()
Definition bfloat16.hpp:358
static CK_TILE_HOST_DEVICE constexpr bfloat16_t round_error()
Definition bfloat16.hpp:352
static CK_TILE_HOST_DEVICE constexpr bfloat16_t min()
Definition bfloat16.hpp:323
static CK_TILE_HOST_DEVICE constexpr bfloat16_t lowest()
Definition bfloat16.hpp:329
static CK_TILE_HOST_DEVICE constexpr bfloat16_t zero()
Definition bfloat16.hpp:380
static CK_TILE_HOST_DEVICE constexpr bfloat16_t denorm_min()
Definition bfloat16.hpp:376
static CK_TILE_HOST_DEVICE constexpr bfloat16_t signaling_NaN()
Definition bfloat16.hpp:370
static constexpr int PackedSize
Definition bfloat16.hpp:391
static constexpr int exp
Definition bfloat16.hpp:389
static constexpr int mant
Definition bfloat16.hpp:390
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/numeric/numeric.hpp:18
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition tile/core/numeric/numeric.hpp:106