math.hpp Source File

math.hpp Source File

Composable Kernel: math.hpp Source File
tile/core/numeric/math.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10#include <type_traits>
11#include <stdint.h>
12#include <cmath>
13
14namespace ck_tile {
15
// Function object that multiplies its argument by the compile-time value `lhs`,
// a non-type template parameter of type Scale. (The struct-name line, listing
// line 17, is missing from this extracted listing.)
16template <typename Scale, Scale lhs>
18{
19 template <typename Right>
20 CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
21 {
22 return lhs * rhs;
23 }
24};
25
// scales: function object that multiplies its argument by a scale value captured
// at construction. The scale is stored by value (lhs_), hence the
// copy-constructibility requirement below.
26template <typename Scale>
27struct scales
28{
29 static_assert(std::is_copy_constructible_v<Scale>);
30
31 CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
32
// Returns lhs_ * rhs; the return type is that of `const Scale& * rhs`.
33 template <typename Right>
34 CK_TILE_HOST_DEVICE constexpr auto
35 operator()(const Right& rhs) const -> decltype(std::declval<const Scale&>() * rhs)
36 {
37 return lhs_ * rhs;
38 }
39
40 private:
41 Scale lhs_;
42};
43
// Deduction guide: `scales(s)` deduces scales<Scale> from the argument type.
// The cross-reference attaches a FIXME here asking for a macro that replaces
// the raw `__host__ __device__`.
45template <typename Scale>
46__host__ __device__ scales(Scale) -> scales<Scale>;
47
// Arithmetic function objects, analogous to std::plus / std::minus /
// std::multiplies. The primary template fixes both operand types; the
// <void, void> specialisation is transparent and deduces the operand types per
// call. Each family ends with a deduction guide so that `plus{}` etc. deduce
// the <void, void> form.
48template <typename Left = void, typename Right = Left>
49struct plus
50{
51 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
52 const Right& rhs) const -> decltype(lhs + rhs)
53 {
54 return lhs + rhs;
55 }
56};
57
58template <>
59struct plus<void, void>
60{
61 template <typename Left, typename Right>
62 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
63 const Right& rhs) const -> decltype(lhs + rhs)
64 {
65 return lhs + rhs;
66 }
67};
68
// Deduction guide: `plus{}` deduces plus<void, void>.
70__host__ __device__ plus() -> plus<void, void>;
71
72template <typename Left = void, typename Right = Left>
73struct minus
74{
75 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
76 const Right& rhs) const -> decltype(lhs - rhs)
77 {
78 return lhs - rhs;
79 }
80};
81
82template <>
83struct minus<void, void>
84{
85 template <typename Left, typename Right>
86 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
87 const Right& rhs) const -> decltype(lhs - rhs)
88 {
89 return lhs - rhs;
90 }
91};
92
// Deduction guide: `minus{}` deduces minus<void, void>.
94__host__ __device__ minus() -> minus<void, void>;
95
// Primary multiplies template (its `struct multiplies` line, listing line 97,
// is missing from this listing).
96template <typename Left = void, typename Right = Left>
98{
99 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
100 const Right& rhs) const -> decltype(lhs * rhs)
101 {
102 return lhs * rhs;
103 }
104};
105
106template <>
107struct multiplies<void, void>
108{
109 template <typename Left, typename Right>
110 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
111 const Right& rhs) const -> decltype(lhs * rhs)
112 {
113 return lhs * rhs;
114 }
115};
116
// Deduction guide: `multiplies{}` deduces multiplies<void, void>.
118__host__ __device__ multiplies() -> multiplies<void, void>;
119
// Binary function object returning the larger of a and b (a when they compare
// equal). Struct-name line 121 is missing from this listing.
120template <typename T>
122{
123 CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
124};
125
// Binary function object returning the smaller of a and b (a when they compare
// equal). Struct-name line 127 is missing from this listing.
126template <typename T>
128{
129 CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
130};
131
// Function object computing (a + b - 1) / b. This is ceil(a / b) when a >= 0
// and b > 0, which is assumed and not checked. Only index_t / int are accepted.
// Struct-name line 133 is missing from this listing.
132template <typename T>
134{
135 CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
136 {
137 static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
138 return (a + b - number<1>{}) / b;
139 }
140};
141
// integer_divide_floor(x, y): x / y. Built-in integer division truncates toward
// zero, so this equals the mathematical floor only when x and y do not have
// opposite signs. Signature line 143 is missing from this listing.
142template <typename X, typename Y>
144{
145 return x / y;
146}
147
// integer_divide_ceil(x, y): (x + y - 1) / y, the ceiling of x / y for x >= 0
// and y > 0 (assumed). Subtracting number<1>{} presumably lets number<>
// arguments stay compile-time constants. Signature line 149 is missing here.
148template <typename X, typename Y>
150{
151 return (x + y - number<1>{}) / y;
152}
153
// integer_least_multiple(x, y): y * ceil(x / y), the smallest multiple of y
// that is >= x, under the same x >= 0, y > 0 assumption.
// Signature line 155 is missing from this listing.
154template <typename X, typename Y>
156{
157 return y * integer_divide_ceil(x, y);
158}
159
// max(x): single-argument base case of the variadic max below; returns x.
160template <typename T>
161CK_TILE_HOST_DEVICE constexpr T max(T x)
162{
163 return x;
164}
165
// Host two-argument max. Returns y when the operands compare equal, and also
// when either operand is NaN, because x > y is then false.
166template <typename T>
167CK_TILE_HOST constexpr T max(T x, T y)
168{
169 return x > y ? x : y;
170}
171
// Device two-argument max, with the same semantics as the host version.
// float and double are specialised below to use the fmax builtins.
172template <typename T>
173CK_TILE_DEVICE constexpr T max(T x, T y)
174{
175 return x > y ? x : y;
176}
177
// Device float max via __builtin_fmaxf. Unlike `x > y ? x : y`, fmax returns
// the non-NaN operand when exactly one argument is NaN, so host and device
// results can differ for NaN inputs.
178template <>
179CK_TILE_DEVICE constexpr float max(float x, float y)
180{
181 return __builtin_fmaxf(x, y); // can result in v_max3_f32
182}
183
// Device double max via __builtin_fmax, with the same NaN handling as the
// float version.
184template <>
185CK_TILE_DEVICE constexpr double max(double x, double y)
186{
187 return __builtin_fmax(x, y); // NOTE(review): was "maybe still v_max3_f32"; that is an f32 instruction, f64 lowering unverified
188}
189
// max(number<X>, y): compile-time left operand; returns X if X > y, else y.
// Signature line 191 is missing from this listing.
190template <index_t X>
192{
193 return X > y ? X : y;
194}
195
// max(x, number<Y>): compile-time right operand; returns x if x > Y, else Y.
// Signature line 197 is missing from this listing.
196template <index_t Y>
198{
199 return x > Y ? x : Y;
200}
201
// Variadic max: max(x, ys...) == max(x, max(ys...)); needs at least one ys.
202template <typename X, typename... Ys>
203CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
204{
205 static_assert(sizeof...(Ys) > 0, "not enough argument");
206 return max(x, max(ys...));
207}
208
// min(x): single-argument base case of the variadic min below; returns x.
209template <typename T>
210CK_TILE_HOST_DEVICE constexpr T min(T x)
211{
212 return x;
213}
214
// Host two-argument min. Returns y when the operands compare equal, and also
// when either operand is NaN, because x < y is then false.
215template <typename T>
216CK_TILE_HOST constexpr T min(T x, T y)
217{
218 return x < y ? x : y;
219}
220
// Device two-argument min, with the same semantics as the host version.
// float and double are specialised below to use the fmin builtins.
221template <typename T>
222CK_TILE_DEVICE constexpr T min(T x, T y)
223{
224 return x < y ? x : y;
225}
226
// Device float min via __builtin_fminf. fmin returns the non-NaN operand when
// exactly one argument is NaN.
227template <>
228CK_TILE_DEVICE constexpr float min(float x, float y)
229{
230 return __builtin_fminf(x, y);
231}
232
// Device double min via __builtin_fmin, with the same NaN handling.
233template <>
234CK_TILE_DEVICE constexpr double min(double x, double y)
235{
236 return __builtin_fmin(x, y);
237}
238
// min(number<X>, y): compile-time left operand; returns X if X < y, else y.
// Signature line 240 is missing from this listing.
239template <index_t X>
241{
242 return X < y ? X : y;
243}
244
// min(x, number<Y>): compile-time right operand; returns x if x < Y, else Y.
// Signature line 246 is missing from this listing.
245template <index_t Y>
247{
248 return x < Y ? x : Y;
249}
250
// Variadic min: min(x, ys...) == min(x, min(ys...)); needs at least one ys.
251template <typename X, typename... Ys>
252CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
253{
254 static_assert(sizeof...(Ys) > 0, "not enough argument");
255 return min(x, min(ys...));
256}
257
// clamp(x, lowerbound, upperbound) = min(max(x, lowerbound), upperbound).
// If lowerbound > upperbound the result is upperbound.
258template <typename T>
259CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
260{
261 return min(max(x, lowerbound), upperbound);
262}
263
// Number of leading zero bits in a 32-bit value. The host overload uses
// __builtin_clz, whose result is undefined for x == 0. The device overload uses
// the HIP __clz intrinsic. Neither overload is declared constexpr.
264CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
265CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
266
267// greatest common divisor, aka highest common factor
// gcd(index_t x, index_t y). Signature line 268 is missing; the
// cross-reference gives `CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)`.
// Negative operands are negated first, then Euclid's algorithm recurses.
// gcd(0, 0) == 0, and gcd(0, y) == y for y >= 0.
// NOTE(review): negating the most negative index_t overflows.
269{
270 if(x < 0)
271 {
272 return gcd(-x, y);
273 }
274 else if(y < 0)
275 {
276 return gcd(x, -y);
277 }
278 else if(x == y || x == 0)
279 {
280 return y;
281 }
282 else if(y == 0)
283 {
284 return x;
285 }
286 else if(x > y)
287 {
288 return gcd(x % y, y);
289 }
290 else
291 {
292 return gcd(x, y % x);
293 }
294}
295
// Compile-time gcd of two number<> constants; the result is number<gcd(X, Y)>.
// Signature line 297 is missing from this listing.
296template <index_t X, index_t Y>
298{
299 constexpr auto r = gcd(X, Y);
300
301 return number<r>{};
302}
303
// gcd of three or more values, folded from the right: gcd(x, gcd(ys...)).
304template <typename X,
305 typename... Ys,
306 typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
307CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
308{
309 return gcd(x, gcd(ys...));
310}
311
312// least common multiple
313template <typename X, typename Y>
314CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
315{
316 return (x * y) / gcd(x, y);
317}
318
319template <typename X,
320 typename... Ys,
321 typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
322CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
323{
324 return lcm(x, lcm(ys...));
325}
326
// equal: `lhs == rhs` function object, analogous to std::equal_to. The primary
// template fixes both operand types; equal<void, void> deduces them per call.
327template <typename Left = void, typename Right = Left>
328struct equal
329{
330 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
331 const Right& rhs) const -> decltype(lhs == rhs)
332 {
333 return lhs == rhs;
334 }
335};
336
// Transparent specialisation. Its struct line (listing line 338) is missing
// from this listing.
337template <>
339{
340 template <typename Left, typename Right>
341 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
342 const Right& rhs) const -> decltype(lhs == rhs)
343 {
344 return lhs == rhs;
345 }
346};
347
// Deduction guide: `equal{}` deduces equal<void, void>.
349__host__ __device__ equal() -> equal<void, void>;
350
// Bitwise float equality, presumably equal<float, float> (line 352 missing).
// Comparing bit patterns differs from `==`:
// - +0.0f and -0.0f compare unequal;
// - a NaN compares equal to a NaN with the identical bit pattern.
351template <>
353{
354 CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
355 {
356 return bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
357 }
358};
359
// Bitwise double equality, presumably equal<double, double> (line 361
// missing), with the same +0/-0 and NaN behaviour as the float version.
360template <>
362{
363 CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
364 {
365 return bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
366 }
367};
368
// less: `lhs < rhs` function object, with the same primary / <void, void>
// layout as equal.
369template <typename Left = void, typename Right = Left>
370struct less
371{
372 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
373 const Right& rhs) const -> decltype(lhs < rhs)
374 {
375 return lhs < rhs;
376 }
377};
378
// Transparent less<void, void> (struct line 380 missing).
379template <>
381{
382 template <typename Left, typename Right>
383 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
384 const Right& rhs) const -> decltype(lhs < rhs)
385 {
386 return lhs < rhs;
387 }
388};
389
// Deduction guide: `less{}` deduces less<void, void>.
391__host__ __device__ less() -> less<void, void>;
392
// less_equal: `lhs <= rhs` function object (struct-name line 394 missing).
393template <typename Left = void, typename Right = Left>
395{
396 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
397 const Right& rhs) const -> decltype(lhs <= rhs)
398 {
399 return lhs <= rhs;
400 }
401};
402
// Transparent less_equal<void, void> (struct line 404 missing).
403template <>
405{
406 template <typename Left, typename Right>
407 CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
408 const Right& rhs) const -> decltype(lhs <= rhs)
409 {
410 return lhs <= rhs;
411 }
412};
413
// Deduction guide: `less_equal{}` deduces less_equal<void, void>.
415__host__ __device__ less_equal() -> less_equal<void, void>;
416
// float less_equal, presumably less_equal<float, float> (line 418 missing).
// It is defined as `lhs < rhs` or bitwise-equal, consistent with the bitwise
// equal above. As a result:
// - a zero compared with the opposite-signed zero yields false;
// - NaN <= NaN is true when the bit patterns match.
417template <>
419{
420 CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
421 {
422 return lhs < rhs || bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
423 }
424};
425
// double less_equal, presumably less_equal<double, double> (line 427
// missing), with the same semantics as the float version.
426template <>
428{
429 CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
430 {
431 return lhs < rhs || bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
432 }
433};
434
// next_power_of_two(x): returns 1 << (32 - clz(x - 1)), the smallest power of
// two >= x. Signature line 435 is missing from this listing.
// NOTE(review): clz() is not declared constexpr, yet the number<> overloads
// below use this function to initialise constexpr variables; confirm the
// toolchain accepts that. __builtin_clz, as used by integer_log2_floor, is
// usable in constant expressions.
// NOTE(review): for x > 0x40000000 the result 1 << 31 does not fit in int32_t,
// so the upper bound in the TODO below is optimistic.
436{
437 // TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
438 return 1 << (32 - clz(x - 1));
439}
440
// Compile-time next_power_of_two of X, returned as number<...>. Signature
// lines 442 and 449 are missing; the cross-reference lists a
// `next_power_of_two()` overload at line 442. Both visible bodies are identical.
441template <index_t X>
443{
444 constexpr index_t y = next_power_of_two(X);
445 return number<y>{};
446}
447
448template <index_t X>
450{
451 constexpr index_t y = next_power_of_two(X);
452 return number<y>{};
453}
454
// integer_log2_floor(x) = 31 - clz(x) = floor(log2(x)). The cross-reference
// gives the signature as `CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)`.
456{
457 // TODO: x need to be 1 ~ 0x7fffffff
458 // __builtin_clz will produce unexpected result if x is 0;
459 return 31 - __builtin_clz(x);
460}
461
// is_power_of_two_integer(x): true iff x equals 1 << floor(log2(x)). The
// cross-reference gives the signature as `CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)`.
463{
464 // TODO: x need to be 1 ~ 0x7fffffff
465 return x == (1 << integer_log2_floor(x));
466}
467
// log2(e), unless already defined elsewhere. For example, the device
// tanh_fast<float> in this file rewrites exp(2x) as exp2(2x * log2e_v<float>).
468#ifndef C_LOG2E
469#define C_LOG2E 1.44269504088896340736 // log2(e)
470#endif
471
// log2e<T>::value holds log2(e) as a T. Only the two specialisations below are
// defined; their headers (lines 476 and 482) are missing from this listing and
// are presumably log2e<double> and log2e<float>, judging by the value types.
472template <typename T>
473struct log2e;
474
475template <>
477{
478 static constexpr double value = C_LOG2E;
479};
480
481template <>
483{
484 static constexpr float value = C_LOG2E;
485};
486
// log2e_v<T>: variable-template shorthand for log2e<T>::value (its definition
// line 488 is missing from this listing; see the cross-reference).
487template <typename T = double>
489
// log2e_rcp_v<T> = 1 / log2(e) = ln(2). The division is done in double
// (`1.`) and then converted to T.
490template <typename T = double>
491constexpr T log2e_rcp_v = 1. / log2e<T>::value;
492
// float exp2 overloads. The qualifier lines 493 and 496 are missing; presumably
// they select device (exp2f) and host (std::exp2f) respectively.
494float exp2(float x) { return exp2f(x); };
495
497float exp2(float x) { return std::exp2f(x); };
498
// Sum of absolute differences: |x - y| + acc.
// sad_u16: signature line 499 is missing; the cross-reference gives it as
// `CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)`.
// It is implemented with the AMDGCN builtin.
500{
501 return __builtin_amdgcn_sad_u16(x, y, acc);
502}
503
// Device sad_u32: signature line 504 per the cross-reference is
// `CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)`.
// It emits v_sad_u32 through inline asm. `volatile` keeps the compiler from
// removing or merging the asm. Listing line 506 is missing; its contents are
// unknown.
505{
507 uint32_t res;
508 asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc));
509 return res;
510}
511
// Portable version (signature line 512 missing; presumably the host sad_u32).
// Subtracting the smaller operand from the larger avoids a negative
// intermediate for unsigned types.
513{
514 return (x > y ? (x - y) : (y - x)) + acc;
515}
516
518
519} // namespace ck_tile
520// the functions below need the data types to be defined first
525#ifndef __HIP_DEVICE_COMPILE__
526#include <cmath>
527#endif
528
529namespace ck_tile {
// With the SWDEV-383542 workaround enabled, declare OCML's native reciprocal.
// In this file it is referenced only by a commented-out line in device rcp().
530#if CK_TILE_WORKAROUND_SWDEV_383542
531extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
532#endif
533
534// math functions for the host, some are implemented by calling C++ std functions
// NOTE(review): the extracted listing drops many signature lines (gaps in the
// listing numbers). Where a signature is missing, the parameter type named
// below is inferred from the body only and should be confirmed.
535
// |x| for float / double via std::abs.
536CK_TILE_HOST float abs(float x) { return std::abs(x); };
537
538CK_TILE_HOST double abs(double x) { return std::abs(x); };
539
// Branchless |x| for (presumably) int8_t; signature line 540 is missing.
// sgn = x >> 7 is all ones for negative x (arithmetic shift assumed), so
// (x ^ sgn) - sgn negates negative values and leaves the others unchanged.
541{
542 int8_t sgn = x >> (8 - 1);
543
544 return (x ^ sgn) - sgn;
545};
546
// Same branchless |x| for (presumably) int32_t; signature line 547 is missing.
// NOTE(review): for INT32_MIN, (x ^ sgn) - sgn overflows.
548{
549 int32_t sgn = x >> (32 - 1);
550
551 return (x ^ sgn) - sgn;
552};
553
// |x| for fp16_t, by clearing the sign bit (mask 0x7fff). Signature line 554 is
// missing. Line 556 is also missing; presumably it is
// `uint16_t xx = bit_cast<uint16_t>(x);`, as in the device overload.
555{
557
558 uint16_t abs_xx = xx & 0x7fff;
559
560 fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
561
562 return abs_x;
563};
564
// Branchless |x| for the experimental 4-bit integer type.
565#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
566CK_TILE_HOST int4_t abs(int4_t x)
567{
568 int4_t sgn = x >> (4 - 1);
569 return (x ^ sgn) - sgn;
570}
571#endif
572
// NaN test for float / double via std::isnan.
573CK_TILE_HOST bool isnan(float x) { return std::isnan(x); };
574
575CK_TILE_HOST bool isnan(double x) { return std::isnan(x); };
576
// Overloads whose signatures (lines 577 and 583) are missing. They
// unconditionally return false; presumably these are integer types, which
// have no NaN.
578{
579 (void)x;
580 return false;
581};
582
584{
585 (void)x;
586 return false;
587};
588
// fp16_t NaN test. A half is NaN iff its exponent bits are all ones (0x7C00)
// and its mantissa is non-zero, i.e. (bits & 0x7FFF) > 0x7C00. Signature line
// 589 and line 591 (presumably the bit_cast to uint16_t `xx`) are missing.
590{
592
593 return (xx & 0x7FFF) > 0x7C00;
594};
595
// int4_t never represents NaN.
596#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
597CK_TILE_HOST bool isnan(int4_t x)
598{
599 (void)x;
600 return false;
601};
602#endif
603
// fp16_t sqrt computed in float and narrowed back; signature line 604 is missing.
605{
606 return static_cast<fp16_t>(std::sqrt(static_cast<float>(x)));
607};
608
609CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
610
611CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
612
// Host transcendental functions all follow one pattern:
// - the generic template converts x to float via type_convert, calls the
//   float C++ library function, and converts the result back to T;
// - the two explicit specialisations after each generic call the library
//   directly. Their signature lines are missing from this listing; by their
//   bodies they are presumably <float> and <double>.
// NOTE(review): the f-suffixed std names (std::tanhf, std::acosf, ...) are not
// declared by every standard library's <cmath>. Older libstdc++ omits them,
// and the HIP wrapper headers are presumably what supplies them here. The
// float overload std::tanh(float) is the portable spelling.
// tanh: generic signature line 614 is `CK_TILE_HOST T tanh(T x)` per the cross-reference.
613template <typename T>
615{
616 return type_convert<T>(std::tanhf(type_convert<float>(x)));
617};
618
619template <>
621{
622 return std::tanhf(x);
623};
624
625template <>
627{
628 return std::tanh(x);
629};
630
// acos: generic signature line 632 is `CK_TILE_HOST T acos(T x)` per the cross-reference.
631template <typename T>
633{
634 return type_convert<T>(std::acosf(type_convert<float>(x)));
635};
636
637template <>
639{
640 return std::acosf(x);
641};
642
643template <>
645{
646 return std::acos(x);
647};
648
// neg: generic signature (line 650, `CK_TILE_HOST T neg(T x)` per the
// cross-reference) and its body (line 652) are missing from this listing.
// Every visible specialisation returns -x; only neg<double> shows its type.
649template <typename T>
651{
653};
654
655template <>
657{
658 return -x;
659};
660
661template <>
662CK_TILE_HOST double neg<double>(double x)
663{
664 return -x;
665};
666
667template <>
669{
670 return -x;
671};
672
673template <>
675{
676 return -x;
677};
678
// Host atan / sin / asin / asinh / cos / acosh / tan. Each generic template
// evaluates in float via type_convert; the specialisations call the float and
// double library functions directly. Specialisation signature lines are
// missing unless shown, e.g. sin<double>, cos<double> and tan<double> are
// visible.
// atan: generic signature line 680 is `CK_TILE_HOST T atan(T x)` per the cross-reference.
679template <typename T>
681{
682 return type_convert<T>(std::atanf(type_convert<float>(x)));
683};
684
685template <>
687{
688 return std::atanf(x);
689};
690
691template <>
693{
694 return std::atan(x);
695};
696
// sin: generic signature line 698 is `CK_TILE_HOST T sin(T x)` per the cross-reference.
697template <typename T>
699{
700 return type_convert<T>(std::sinf(type_convert<float>(x)));
701};
702
703template <>
705{
706 return std::sinf(x);
707};
708
709template <>
710CK_TILE_HOST double sin<double>(double x)
711{
712 return std::sin(x);
713};
714
// asin: generic signature line 716 is `CK_TILE_HOST T asin(T x)` per the cross-reference.
715template <typename T>
717{
718 return type_convert<T>(std::asinf(type_convert<float>(x)));
719};
720
721template <>
723{
724 return std::asinf(x);
725};
726
727template <>
729{
730 return std::asin(x);
731};
732
// asinh: generic signature line 734 is `CK_TILE_HOST T asinh(T x)` per the cross-reference.
733template <typename T>
735{
736 return type_convert<T>(std::asinhf(type_convert<float>(x)));
737};
738
739template <>
741{
742 return std::asinhf(x);
743};
744
745template <>
747{
748 return std::asinh(x);
749};
750
// cos: generic signature line 752 is `CK_TILE_HOST T cos(T x)` per the cross-reference.
751template <typename T>
753{
754 return type_convert<T>(std::cosf(type_convert<float>(x)));
755};
756
757template <>
759{
760 return std::cosf(x);
761};
762
763template <>
764CK_TILE_HOST double cos<double>(double x)
765{
766 return std::cos(x);
767};
768
// acosh: generic signature line 770 is `CK_TILE_HOST T acosh(T x)` per the cross-reference.
769template <typename T>
771{
772 return type_convert<T>(std::acoshf(type_convert<float>(x)));
773};
774
775template <>
777{
778 return std::acoshf(x);
779};
780
781template <>
783{
784 return std::acosh(x);
785};
786
// tan: generic signature line 788 is `CK_TILE_HOST T tan(T x)` per the cross-reference.
787template <typename T>
789{
790 return type_convert<T>(std::tanf(type_convert<float>(x)));
791};
792
793template <>
795{
796 return std::tanf(x);
797};
798
799template <>
800CK_TILE_HOST double tan<double>(double x)
801{
802 return std::tan(x);
803};
804
// Host atanh / sinh / ceil / cosh / floor / rcp, with the same pattern as
// above: the generic template evaluates in float via type_convert, and the
// specialisations (signature lines missing) call the float and double library
// functions.
// atanh: generic signature line 806 is `CK_TILE_HOST T atanh(T x)` per the cross-reference.
805template <typename T>
807{
808 return type_convert<T>(std::atanhf(type_convert<float>(x)));
809};
810
811template <>
813{
814 return std::atanhf(x);
815};
816
817template <>
819{
820 return std::atanh(x);
821};
822
// sinh: generic signature line 824 is `CK_TILE_HOST T sinh(T x)` per the cross-reference.
823template <typename T>
825{
826 return type_convert<T>(std::sinhf(type_convert<float>(x)));
827};
828
829template <>
831{
832 return std::sinhf(x);
833};
834
835template <>
837{
838 return std::sinh(x);
839};
840
// ceil: generic signature line 842 is `CK_TILE_HOST T ceil(T x)` per the cross-reference.
841template <typename T>
843{
844 return type_convert<T>(std::ceilf(type_convert<float>(x)));
845};
846
847template <>
849{
850 return std::ceilf(x);
851};
852
853template <>
855{
856 return std::ceil(x);
857};
858
// cosh: generic signature line 860 is `CK_TILE_HOST T cosh(T x)` per the cross-reference.
859template <typename T>
861{
862 return type_convert<T>(std::coshf(type_convert<float>(x)));
863};
864
865template <>
867{
868 return std::coshf(x);
869};
870
871template <>
873{
874 return std::cosh(x);
875};
876
// floor: generic signature line 878 is `CK_TILE_HOST T floor(T x)` per the cross-reference.
877template <typename T>
879{
880 return type_convert<T>(std::floorf(type_convert<float>(x)));
881};
882
883template <>
885{
886 return std::floorf(x);
887};
888
889template <>
891{
892 return std::floor(x);
893};
894
// rcp: 1 / x computed in float and converted back to T. The generic signature
// line 896 is `CK_TILE_HOST T rcp(T x)` per the cross-reference.
895template <typename T>
897{
898 return type_convert<T>(1.f / type_convert<float>(x));
899};
900
// Host exp / log / pow / expm1. The generic templates evaluate in float via
// type_convert; the <float> and <double> specialisations call the library
// directly. The host double specialisations use qualified std:: calls, so they
// do not resolve back to the ck_tile templates.
// exp: the generic signature (line 902) and the float specialisation's
// signature (line 908) are missing.
901template <typename T>
903{
904 return type_convert<T>(std::expf(type_convert<float>(x)));
905}
906
907template <>
909{
910 return std::expf(x);
911}
912
913template <>
914CK_TILE_HOST double exp<double>(double x)
915{
916 return std::exp(x);
917}
918
// log: the generic signature (line 920) and the float specialisation's
// signature (line 926) are missing.
919template <typename T>
921{
922 return type_convert<T>(std::logf(type_convert<float>(x)));
923}
924
925template <>
927{
928 return std::logf(x);
929}
930
931template <>
932CK_TILE_HOST double log<double>(double x)
933{
934 return std::log(x);
935}
936
// pow: x raised to gamma. The generic form converts both operands to float.
937template <typename T>
938CK_TILE_HOST T pow(T x, T gamma)
939{
940 return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
941}
942
943template <>
944CK_TILE_HOST float pow<float>(float x, float gamma)
945{
946 return std::powf(x, gamma);
947}
948
949template <>
950CK_TILE_HOST double pow<double>(double x, double gamma)
951{
952 return std::pow(x, gamma);
953}
954
// expm1: exp(x) - 1 via the library expm1. The generic signature line 956 is
// `CK_TILE_HOST T expm1(T x)` per the cross-reference; the specialisation
// signature lines are missing.
955template <typename T>
957{
958 return type_convert<T>(std::expm1f(type_convert<float>(x)));
959}
960
961template <>
963{
964 return std::expm1f(x);
965}
966
967template <>
969{
970 return std::expm1(x);
971}
972
973// math functions for the HIP kernel, some are implemented by calling hip builtin functions
974
// Device float |x|: clears the IEEE sign bit through a float/uint32_t union.
// NOTE(review): reading a union member other than the one last written is not
// sanctioned by ISO C++. bit_cast, as used by the fp16_t overload below, is the
// portable form.
975CK_TILE_DEVICE float abs(float x)
976{
977 union
978 {
979 float f32;
980 uint32_t u32;
981 } y;
982 y.f32 = x;
983 y.u32 = y.u32 & 0x7fffffff;
984 return y.f32;
985};
986
// Device double |x| via ::abs. This relies on a global abs(double) overload
// being visible, presumably from the HIP math headers.
987CK_TILE_DEVICE double abs(double x) { return ::abs(x); };
988
// Branchless integer |x|, the same technique as the host versions. Signature
// lines 989 and 996 are missing; presumably int8_t and int32_t, judging by the
// sgn types.
990{
991 int8_t sgn = x >> (8 - 1);
992
993 return (x ^ sgn) - sgn;
994};
995
997{
998 int32_t sgn = x >> (32 - 1);
999
1000 return (x ^ sgn) - sgn;
1001};
1002
1003#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
1004CK_TILE_DEVICE int4_t abs(int4_t x)
1005{
1006 int4_t sgn = x >> (4 - 1);
1007
1008 return (x ^ sgn) - sgn;
1009};
1010#endif
1011
// fp16_t |x| by clearing bit 15 of the raw half bits.
1012CK_TILE_DEVICE fp16_t abs(fp16_t x)
1013{
1014 uint16_t xx = bit_cast<uint16_t>(x);
1015
1016 uint16_t abs_xx = xx & 0x7fff;
1017
1018 fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
1019
1020 return abs_x;
1021};
1022
1023CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); };
1024
1025CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); };
1026
// Overloads with missing signatures (lines 1027 and 1033) that always return
// false; presumably integer types, which have no NaN.
1028{
1029 (void)x;
1030 return false;
1031};
1032
1034{
1035 (void)x;
1036 return false;
1037};
1038
1039#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
1040CK_TILE_DEVICE bool isnan(int4_t x)
1041{
1042 (void)x;
1043 return false;
1044};
1045#endif
1046
// fp16_t NaN test: all-ones exponent (0x7C00) with a non-zero mantissa.
1047CK_TILE_DEVICE bool isnan(fp16_t x)
1048{
1049 uint16_t xx = bit_cast<uint16_t>(x);
1050
1051 return (xx & 0x7FFF) > 0x7C00;
1052};
1053
// Device sqrt via the AMDGCN sqrt builtins. fp16_t goes through float.
1054CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
1055{
1056 return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
1057};
1058
1059CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
1060
1061CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
1062
// Device transcendental functions: the generic template evaluates in float via
// type_convert and the global float function (::tanhf, ...). The
// specialisations call the global float/double functions with explicit ::
// qualification, which avoids resolving back to the ck_tile templates.
// Specialisation signature lines are missing from this listing.
1063template <typename T>
1065{
1066 return type_convert<T>(::tanhf(type_convert<float>(x)));
1067};
1068
1069template <>
1071{
1072 return ::tanhf(x);
1073};
1074
1075template <>
1077{
1078 return ::tanh(x);
1079};
1080
1081template <typename T>
1083{
1084 return type_convert<T>(::acosf(type_convert<float>(x)));
1085};
1086
1087template <>
1089{
1090 return ::acosf(x);
1091};
1092
1093template <>
1095{
1096 return ::acos(x);
1097};
1098
// Device neg: the generic body (line 1102) is missing from this listing. The
// five specialisations (types not visible) all return -x.
1099template <typename T>
1101{
1103};
1104
1105template <>
1107{
1108 return -x;
1109};
1110
1111template <>
1113{
1114 return -x;
1115};
1116
1117template <>
1119{
1120 return -x;
1121};
1122
1123template <>
1125{
1126 return -x;
1127};
1128
1129template <>
1131{
1132 return -x;
1133};
1134
// Device atan / sin / asin / asinh, with the same pattern: the generic
// template evaluates in float; the specialisations call the ::-qualified
// float/double functions. sin has an extra specialisation (presumably fp16_t)
// using __ocml_sin_f16.
1135template <typename T>
1137{
1138 return type_convert<T>(::atanf(type_convert<float>(x)));
1139};
1140
1141template <>
1143{
1144 return ::atanf(x);
1145};
1146
1147template <>
1149{
1150 return ::atan(x);
1151};
1152
1153template <typename T>
1155{
1156 return type_convert<T>(::sinf(type_convert<float>(x)));
1157};
1158
1159template <>
1161{
1162 return ::sinf(x);
1163};
1164
1165template <>
1167{
1168 return ::sin(x);
1169};
1170
// Half-precision sin via OCML (specialisation type presumably fp16_t).
1171template <>
1173{
1174 return __ocml_sin_f16(x);
1175};
1176
1177template <typename T>
1179{
1180 return type_convert<T>(::asinf(type_convert<float>(x)));
1181};
1182
1183template <>
1185{
1186 return ::asinf(x);
1187};
1188
1189template <>
1191{
1192 return ::asin(x);
1193};
1194
1195template <typename T>
1197{
1198 return type_convert<T>(::asinhf(type_convert<float>(x)));
1199};
1200
1201template <>
1203{
1204 return ::asinhf(x);
1205};
1206
1207template <>
1209{
1210 return ::asinh(x);
1211};
1212
// Device acosh / tan / atanh / sinh: the generic template evaluates in float;
// the specialisations (signature lines missing) call the ::-qualified
// float/double functions.
1213template <typename T>
1215{
1216 return type_convert<T>(::acoshf(type_convert<float>(x)));
1217};
1218
1219template <>
1221{
1222 return ::acoshf(x);
1223};
1224
1225template <>
1227{
1228 return ::acosh(x);
1229};
1230
1231template <typename T>
1233{
1234 return type_convert<T>(::tanf(type_convert<float>(x)));
1235};
1236
1237template <>
1239{
1240 return ::tanf(x);
1241};
1242
1243template <>
1245{
1246 return ::tan(x);
1247};
1248
1249template <typename T>
1251{
1252 return type_convert<T>(::atanhf(type_convert<float>(x)));
1253};
1254
1255template <>
1257{
1258 return ::atanhf(x);
1259};
1260
1261template <>
1263{
1264 return ::atanh(x);
1265};
1266
1267template <typename T>
1269{
1270 return type_convert<T>(::sinhf(type_convert<float>(x)));
1271};
1272
1273template <>
1275{
1276 return ::sinhf(x);
1277};
1278
1279template <>
1281{
1282 return ::sinh(x);
1283};
1284
// Device ceil / cosh / floor: the generic template evaluates in float; the
// specialisations call the ::-qualified float/double functions. ceil and floor
// also have a half-precision specialisation (presumably fp16_t) using OCML.
1285template <typename T>
1287{
1288 return type_convert<T>(::ceilf(type_convert<float>(x)));
1289};
1290
1291template <>
1293{
1294 return ::ceilf(x);
1295};
1296
1297template <>
1299{
1300 return ::ceil(x);
1301};
1302
1303template <>
1305{
1306 return __ocml_ceil_f16(x);
1307};
1308
1309template <typename T>
1311{
1312 return type_convert<T>(::coshf(type_convert<float>(x)));
1313};
1314
1315template <>
1317{
1318 return ::coshf(x);
1319};
1320
1321template <>
1323{
1324 return ::cosh(x);
1325};
1326
1327template <typename T>
1329{
1330 return type_convert<T>(::floorf(type_convert<float>(x)));
1331};
1332
1333template <>
1335{
1336 return ::floorf(x);
1337};
1338
1339template <>
1341{
1342 return ::floor(x);
1343};
1344
1345template <>
1347{
1348 return __ocml_floor_f16(x);
1349};
1350
// Device reciprocal. Without the SWDEV-383542 workaround it uses __frcp_rn;
// with the workaround it uses the AMDGCN rcp builtin.
// NOTE(review): x is passed straight to float builtins in both branches, so
// this is presumably only instantiated with T = float; confirm.
1351template <typename T>
1353{
1354#if !CK_TILE_WORKAROUND_SWDEV_383542
1355 return __frcp_rn(x);
1356#else
1357 // return __ocml_native_recip_f32(x);
1358 return __builtin_amdgcn_rcpf(x);
1359#endif
1360};
1361
// Device exp. The generic template goes through float and __ocml_exp_f32. The
// specialisations (signature lines missing) use __ocml_exp_f16 (presumably
// fp16_t) and __ocml_exp_f32 (float).
1362template <typename T>
1364{
1365 return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
1366};
1367
1368template <>
1370{
1371 return __ocml_exp_f16(x);
1372};
1373
1374template <>
1376{
1377 return __ocml_exp_f32(x);
1378};
1379
1380template <>
1382{
1383 return exp(x);
1384};
1385
1386template <typename T>
1388{
1389 return type_convert<T>((exp<T>(2.0 * type_convert<float>(x)) - 1.0) /
1390 (exp<T>(2.0 * type_convert<float>(x)) + 1.0));
1391};
1392
// tanh_fast for float. The signature line 1394 is missing; presumably it is
// `CK_TILE_DEVICE float tanh_fast<float>(float x)`.
// Computes tanh(x) = 1 - 2 / (e^{2x} + 1), where e^{2x} is formed as
// exp2(2 * log2(e) * x) with the AMDGCN exp2/rcp builtins.
// The commented-out variants are disabled alternatives: a sinh/cosh ratio, and
// a |x|-based form with a small-|x| cutoff and sign restore.
1393template <>
1395{
1396 // float a = __builtin_amdgcn_sinh(x);
1397 // float b = __builtin_amdgcn_cosh(x);
1398 // float e = a * __builtin_amdgcn_rcpf(b);
1399 // return e;
1400
// a = e^{2x}, then a = 1 / (e^{2x} + 1), then the result is 1 - 2a.
1401 float a = 2.0f * log2e_v<float> * x;
1402 a = __builtin_amdgcn_exp2f(a);
1403 a = __builtin_amdgcn_rcpf(a + 1.0f);
1404 a = 2 * a;
1405 a = 1 - a;
1406 return a;
1407
1408 // float e, r, s, t, d;
1409 // float a = x;
1410 // s = abs(a);
1411 // t = -log2e_v<float> * 2.0f * s;
1412 // e = __builtin_amdgcn_exp2f(t);
1413 // d = e + 1.0f;
1414 // r = __builtin_amdgcn_rcpf(d);
1415 // r = e * (-r) + r;
1416 // if (s < 4.997253418e-3f) r = a;
1417 // union fipnr {float f; unsigned int i;};
1418 // fipnr r_; r_.f = r;
1419 // fipnr a_; a_.f = a;
1420 // { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
1421 // return r;
1422};
1423
// Device log. The generic template goes through float and the __logf
// intrinsic. The specialisations (signature lines missing) use __ocml_log_f16
// (presumably fp16_t) and __logf (float).
// NOTE(review): __logf is the HIP fast-math log intrinsic, presumably of lower
// accuracy than ::logf.
1424template <typename T>
1426{
1427 return type_convert<T>(__logf(type_convert<float>(x)));
1428};
1429
1430template <>
1432{
1433 return __ocml_log_f16(x);
1434};
1435
1436template <>
1438{
1439 return __logf(x);
1440};
1441
1442template <>
1444{
1445 return log(x);
1446};
1447
// Device pow: x raised to gamma. The generic body (line 1451) is missing from
// this listing.
1448template <typename T>
1449CK_TILE_DEVICE T pow(T x, T gamma)
1450{
1452};
1453
// float pow. The unqualified powf is not declared in ck_tile, so lookup
// reaches the global ::powf.
1454template <>
1455CK_TILE_DEVICE float pow<float>(float x, float gamma)
1456{
1457 return powf(x, gamma);
1458};
1459
1460template <>
1461CK_TILE_DEVICE double pow<double>(double x, double gamma)
1462{
1463 return pow(x, gamma);
1464};
1465
// Device expm1 (exp(x) - 1). The generic template goes through float and
// expm1f; the float specialisation (signature line 1473 missing) calls expm1f
// directly. The unqualified expm1f is not declared in ck_tile, so lookup
// reaches the global function.
1466template <typename T>
1468{
1469 return type_convert<T>(expm1f(type_convert<float>(x)));
1470};
1471
1472template <>
1474{
1475 return expm1f(x);
1476};
1477
1478template <>
1480{
1481 return expm1(x);
1482};
1483
1484} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
__host__ __device__ less_equal() -> less_equal< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
CK_TILE_HOST T acos(T x)
Definition tile/core/numeric/math.hpp:632
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
int8_t int8_t
Definition int8.hpp:20
CK_TILE_HOST T cos(T x)
Definition tile/core/numeric/math.hpp:752
_Float16 fp16_t
Definition half.hpp:110
CK_TILE_HOST_DEVICE constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition tile/core/numeric/math.hpp:259
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_HOST T ceil(T x)
Definition tile/core/numeric/math.hpp:842
CK_TILE_HOST T acosh(T x)
Definition tile/core/numeric/math.hpp:770
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST T expm1(T x)
Definition tile/core/numeric/math.hpp:956
CK_TILE_DEVICE T tanh_fast(T x)
Definition tile/core/numeric/math.hpp:1387
__host__ __device__ minus() -> minus< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_HOST T tanh(T x)
Definition tile/core/numeric/math.hpp:614
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
__host__ __device__ less() -> less< void, void >
FIXME: create macro to replace 'host device' and nothing more.
__host__ __device__ scales(Scale) -> scales< Scale >
FIXME: create macro to replace 'host device' and nothing more.
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST T atan(T x)
Definition tile/core/numeric/math.hpp:680
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
Definition tile/core/numeric/math.hpp:268
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST T sin(T x)
Definition tile/core/numeric/math.hpp:698
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
__host__ __device__ equal() -> equal< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
Definition tile/core/numeric/math.hpp:504
CK_TILE_HOST T floor(T x)
Definition tile/core/numeric/math.hpp:878
CK_TILE_HOST T sinh(T x)
Definition tile/core/numeric/math.hpp:824
constexpr T log2e_rcp_v
Definition tile/core/numeric/math.hpp:491
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
Definition tile/core/numeric/math.hpp:442
CK_TILE_HOST T asin(T x)
Definition tile/core/numeric/math.hpp:716
CK_TILE_HOST T asinh(T x)
Definition tile/core/numeric/math.hpp:734
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST int clz(uint32_t x)
Definition tile/core/numeric/math.hpp:264
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
Definition tile/core/numeric/math.hpp:499
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
Definition tile/core/numeric/math.hpp:455
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
CK_TILE_HOST T atanh(T x)
Definition tile/core/numeric/math.hpp:806
CK_TILE_HOST T neg(T x)
Definition tile/core/numeric/math.hpp:650
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition bfloat16.hpp:406
CK_TILE_HOST T tan(T x)
Definition tile/core/numeric/math.hpp:788
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST T cosh(T x)
Definition tile/core/numeric/math.hpp:860
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
Definition tile/core/numeric/math.hpp:314
CK_TILE_HOST T pow(T x, T gamma)
Definition tile/core/numeric/math.hpp:938
CK_TILE_HOST T rcp(T x)
Definition tile/core/numeric/math.hpp:896
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
__host__ __device__ plus() -> plus< void, void >
FIXME: create macro to replace 'host device' and nothing more.
_BitInt(4) int4_t
Definition data_type.hpp:32
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
Definition tile/core/numeric/math.hpp:363
CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
Definition tile/core/numeric/math.hpp:354
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs==rhs)
Definition tile/core/numeric/math.hpp:341
Definition tile/core/numeric/math.hpp:329
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs==rhs)
Definition tile/core/numeric/math.hpp:330
Definition tile/core/numeric/math.hpp:134
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
Definition tile/core/numeric/math.hpp:135
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs< rhs)
Definition tile/core/numeric/math.hpp:383
CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
Definition tile/core/numeric/math.hpp:429
CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
Definition tile/core/numeric/math.hpp:420
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs<=rhs)
Definition tile/core/numeric/math.hpp:407
Definition tile/core/numeric/math.hpp:395
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs<=rhs)
Definition tile/core/numeric/math.hpp:396
Definition tile/core/numeric/math.hpp:371
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs< rhs)
Definition tile/core/numeric/math.hpp:372
static constexpr double value
Definition tile/core/numeric/math.hpp:478
static constexpr float value
Definition tile/core/numeric/math.hpp:484
Definition tile/core/numeric/math.hpp:473
Definition tile/core/numeric/math.hpp:122
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
Definition tile/core/numeric/math.hpp:123
Definition tile/core/numeric/math.hpp:128
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
Definition tile/core/numeric/math.hpp:129
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs - rhs)
Definition tile/core/numeric/math.hpp:86
Definition tile/core/numeric/math.hpp:74
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs - rhs)
Definition tile/core/numeric/math.hpp:75
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs *rhs)
Definition tile/core/numeric/math.hpp:110
Definition tile/core/numeric/math.hpp:98
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs *rhs)
Definition tile/core/numeric/math.hpp:99
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs+rhs)
Definition tile/core/numeric/math.hpp:62
Definition tile/core/numeric/math.hpp:50
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left &lhs, const Right &rhs) const -> decltype(lhs+rhs)
Definition tile/core/numeric/math.hpp:51
Definition tile/core/numeric/math.hpp:18
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right &rhs) const -> decltype(lhs *rhs)
Definition tile/core/numeric/math.hpp:20
Definition tile/core/numeric/math.hpp:28
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right &rhs) const -> decltype(std::declval< const Scale & >() *rhs)
Definition tile/core/numeric/math.hpp:35
CK_TILE_HOST_DEVICE constexpr scales(Scale lhs)
Definition tile/core/numeric/math.hpp:31
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469