binary_element_wise_operation.hpp Source File

binary_element_wise_operation.hpp Source File#

Composable Kernel: binary_element_wise_operation.hpp Source File
binary_element_wise_operation.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10namespace tensor_operation {
11namespace element_wise {
12
13struct Add
14{
15 static constexpr const char* name = "Add";
16
17 template <typename Y, typename X0, typename X1>
18 __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
19
20 template <>
21 __host__ __device__ constexpr void
22 operator()<float>(float& y, const float& x0, const float& x1) const
23 {
24 y = x0 + x1;
25 };
26
27 template <>
28 __host__ __device__ constexpr void
29 operator()<double>(double& y, const double& x0, const double& x1) const
30 {
31 y = x0 + x1;
32 };
33
34 template <>
35 __host__ __device__ constexpr void
36 operator()<float>(float& y, const float& x0, const half_t& x1) const
37 {
38 y = x0 + type_convert<half_t>(x1);
39 };
40
41 template <>
42 __host__ __device__ constexpr void
43 operator()<half_t>(half_t& y, const float& x0, const float& x1) const
44 {
45 y = type_convert<half_t>(x0 + x1);
46 };
47
48 template <>
49 __host__ __device__ constexpr void
50 operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
51 {
52 y = x0 + type_convert<float>(x1);
53 };
54
55 template <>
56 __host__ __device__ constexpr void
57 operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
58 {
59 y = x0 + x1;
60 };
61
62 template <>
63 __host__ __device__ constexpr void
64 operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
65 {
66 const float x1_tmp = ck::type_convert<float>(x1);
67 y = x0 + x1_tmp;
68 }
69
70 template <>
71 __host__ __device__ constexpr void
72 operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
73 {
74 const float x1_tmp = ck::type_convert<float>(x0);
75 const float x2_tmp = ck::type_convert<float>(x1);
76 const float y_tmp = x1_tmp + x2_tmp;
78 }
79
80 template <>
81 __host__ __device__ constexpr void
82 operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
83 {
84 const float x2_tmp = ck::type_convert<float>(x1);
85 const float y_tmp = x0 + x2_tmp;
87 }
88
89 template <>
90 __host__ __device__ constexpr void
91 operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
92 {
93 y = x0 + x1;
94 };
95};
96
97struct Max
98{
99 static constexpr const char* name = "Max";
100
101 template <typename Y, typename X0, typename X1>
102 __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
103 {
104 const Y x0_converted = type_convert<Y>(x0);
105 const Y x1_converted = type_convert<Y>(x1);
106 y = ck::math::max(x0_converted, x1_converted);
107 }
108};
109
110struct Min
111{
112 static constexpr const char* name = "Min";
113
114 template <typename Y, typename X0, typename X1>
115 __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
116 {
117 const Y x0_converted = type_convert<Y>(x0);
118 const Y x1_converted = type_convert<Y>(x1);
119 y = ck::math::min(x0_converted, x1_converted);
120 }
121};
122
124{
125 static constexpr const char* name = "Multiply";
126
127 template <typename Y, typename X0, typename X1>
128 __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
129
130 template <>
131 __host__ __device__ constexpr void
132 operator()<float>(float& y, const float& x0, const float& x1) const
133 {
134 y = x0 * x1;
135 };
136
137 template <>
138 __host__ __device__ constexpr void
139 operator()<double>(double& y, const double& x0, const double& x1) const
140 {
141 y = x0 * x1;
142 };
143
144 template <>
145 __host__ __device__ constexpr void
146 operator()<float>(float& y, const float& x0, const half_t& x1) const
147 {
148 y = x0 * type_convert<half_t>(x1);
149 };
150
151 template <>
152 __host__ __device__ constexpr void
153 operator()<half_t>(half_t& y, const float& x0, const float& x1) const
154 {
155 y = type_convert<half_t>(x0 * x1);
156 };
157
158 template <>
159 __host__ __device__ constexpr void
160 operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
161 {
162 y = type_convert<half_t>(x0) * x1;
163 };
164
165 template <>
166 __host__ __device__ constexpr void
167 operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
168 {
169 y = x0 * x1;
170 };
171
172 template <>
173 __host__ __device__ constexpr void
174 operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
175 {
176 const float x1_tmp = ck::type_convert<float>(x1);
177 y = x0 * x1_tmp;
178 }
179
180 template <>
181 __host__ __device__ constexpr void
182 operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
183 {
184 const float x1_tmp = ck::type_convert<float>(x0);
185 const float x2_tmp = ck::type_convert<float>(x1);
186 const float y_tmp = x1_tmp * x2_tmp;
187 y = ck::type_convert<bhalf_t>(y_tmp);
188 }
189
190 template <>
191 __host__ __device__ constexpr void
192 operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
193 {
194 const float x1_tmp = ck::type_convert<float>(x0);
195 const float x2_tmp = ck::type_convert<float>(x1);
196 const float y_tmp = x1_tmp * x2_tmp;
197 y = ck::type_convert<bhalf_t>(y_tmp);
198 }
199
200 template <>
201 __host__ __device__ constexpr void
202 operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
203 {
204 const float x2_tmp = ck::type_convert<float>(x1);
205 const float y_tmp = x0 * x2_tmp;
206 y = ck::type_convert<bhalf_t>(y_tmp);
207 }
208
209 template <>
210 __host__ __device__ constexpr void
211 operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
212 {
213 y = x0 * x1;
214 };
215};
216
218{
219 static constexpr const char* name = "ScaleAdd";
220
221 __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
222
223 template <typename Y, typename X0, typename X1>
224 __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
225 {
227 }
228
229 template <>
230 __host__ __device__ void
231 operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
232 {
233 y = scale_ * x0 + ck::type_convert<float>(x1);
234 };
235
236 template <>
237 __host__ __device__ void
238 operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
239 {
240 y = scale_ * x0 + ck::type_convert<float>(x1);
241 };
242
243 float scale_;
244};
245
247{
248 static constexpr const char* name = "Subtract";
249
250 template <typename T>
251 __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
252
253 template <>
254 __host__ __device__ constexpr void
255 operator()<float>(float& y, const float& x0, const float& x1) const
256 {
257 y = x0 - x1;
258 };
259
260 template <>
261 __host__ __device__ constexpr void
262 operator()<double>(double& y, const double& x0, const double& x1) const
263 {
264 y = x0 - x1;
265 };
266
267 template <>
268 __host__ __device__ constexpr void
269 operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
270 {
271 y = x0 - x1;
272 };
273
274 template <>
275 __host__ __device__ constexpr void
276 operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
277 {
278 const float x1_tmp = ck::type_convert<float>(x0);
279 const float x2_tmp = ck::type_convert<float>(x1);
280 const float y_tmp = x1_tmp - x2_tmp;
281 y = ck::type_convert<bhalf_t>(y_tmp);
282 }
283
284 template <>
285 __host__ __device__ constexpr void
286 operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
287 {
288 y = x0 - x1;
289 };
290};
291
293{
294 static constexpr const char* name = "Bilinear";
295
296 Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
297
298 template <typename Y, typename X0, typename X1>
299 __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
300
301 template <>
302 __host__ __device__ constexpr void
303 operator()<double, double, double>(double& y, const double& x0, const double& x1) const
304 {
305 y = alpha_ * x0 + beta_ * x1;
306 };
307
308 template <>
309 __host__ __device__ constexpr void
310 operator()<float, float, float>(float& y, const float& x0, const float& x1) const
311 {
312 y = alpha_ * x0 + beta_ * x1;
313 };
314
315 template <>
316 __host__ __device__ constexpr void
317 operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
318 {
321 };
322
323 template <>
324 __host__ __device__ constexpr void
325 operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
326 {
328 };
329
330 template <>
331 __host__ __device__ constexpr void
332 operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
333 {
335 };
336
337 template <>
338 __host__ __device__ constexpr void
339 operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
340 {
341 const float x0_tmp = type_convert<float>(x0);
342 const float x1_tmp = type_convert<float>(x1);
343 const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
344 y = type_convert<bhalf_t>(y_tmp);
345 };
346
347 template <>
348 __host__ __device__ constexpr void
349 operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
350 {
351 const float x1_tmp = ck::type_convert<float>(x1);
352 const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
353 y = y_tmp;
354 };
355
356 template <>
357 __host__ __device__ constexpr void
358 operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
359 {
362 };
363
364 float alpha_;
365 float beta_;
366};
367
369{
370 static constexpr const char* name = "AddClamp";
371
372 AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
373 : floor_(floor), ceil_(ceil){};
374
375 template <typename Y, typename X0, typename X1>
376 __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
377
378 template <>
379 __host__ __device__ constexpr void
380 operator()<float, float, float>(float& y, const float& x0, const float& x1) const
381 {
382 const float a = x0 + x1;
383 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
384 };
385
386 template <>
387 __host__ __device__ constexpr void
388 operator()<double, double, double>(double& y, const double& x0, const double& x1) const
389 {
390 const double a = x0 + x1;
391 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
392 };
393
394 template <>
395 __host__ __device__ constexpr void
396 operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
397 {
398 const half_t floor = type_convert<half_t>(floor_);
399 const half_t ceil = type_convert<half_t>(ceil_);
400 const half_t a = x0 + x1;
401 y = a > floor ? (a < ceil ? a : ceil) : floor;
402 };
403
404 template <>
405 __host__ __device__ constexpr void
406 operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
407 {
408 const float a = x0 + type_convert<float>(x1);
409 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
411 };
412
413 template <>
414 __host__ __device__ constexpr void
415 operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
416 {
417 const float a = x0 + type_convert<float>(x1);
418 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
419 };
420
421 template <>
422 __host__ __device__ constexpr void
423 operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
424 {
425 const float a = x0 + type_convert<float>(x1);
426 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
428 };
429
430 template <>
431 __host__ __device__ constexpr void
432 operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
433 {
434 const float a = type_convert<float>(x0) + type_convert<float>(x1);
435 const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
437 };
438
439 template <>
440 __host__ __device__ constexpr void
441 operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
442 {
443 const int8_t a = x0 + x1;
444 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
445 };
446
447 template <>
448 __host__ __device__ constexpr void
449 operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
450 {
451 const int8_t a = x0 + x1;
452 y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
453 };
454
455 const float floor_;
456 const float ceil_;
457};
458
460{
461 static constexpr const char* name = "AddRelu";
462
463 template <typename Y, typename X0, typename X1>
464 __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
465
466 template <>
467 __host__ __device__ constexpr void
468 operator()<float, float, float>(float& y, const float& x0, const float& x1) const
469 {
470 const float a = x0 + x1;
471 y = a > 0.0f ? a : 0.0f;
472 };
473
474 template <>
475 __host__ __device__ constexpr void
476 operator()<double, double, double>(double& y, const double& x0, const double& x1) const
477 {
478 const double a = x0 + x1;
479 y = a > 0.0 ? a : 0.0;
480 };
481
482 template <>
483 __host__ __device__ constexpr void
484 operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
485 {
486 const half_t a = x0 + x1;
487 y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
488 };
489
490 template <>
491 __host__ __device__ constexpr void
492 operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
493 {
494 const float a = x0 + type_convert<float>(x1);
495 const float b = a > 0.0f ? a : 0.0f;
497 };
498
499 template <>
500 __host__ __device__ constexpr void
501 operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
502 {
503 const float a = x0 + type_convert<float>(x1);
504 y = a > 0.0f ? a : 0.0f;
505 };
506
507 template <>
508 __host__ __device__ constexpr void
509 operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
510 {
511 const float a = x0 + type_convert<float>(x1);
512 const float b = a > 0.0f ? a : 0.0f;
514 };
515
516 template <>
517 __host__ __device__ constexpr void
518 operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
519 {
520 const float a = type_convert<float>(x0) + type_convert<float>(x1);
521 const float b = a > 0.0f ? a : 0.0f;
523 };
524
525 template <>
526 __host__ __device__ constexpr void
527 operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
528 {
529 const int8_t a = x0 + x1;
530 y = a > 0 ? a : 0;
531 };
532
533 template <>
534 __host__ __device__ constexpr void
535 operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
536 {
537 const int8_t a = x0 + x1;
538 y = a > 0 ? a : 0;
539 };
540};
541
543{
544 static constexpr const char* name = "AddHardswish";
545
546 template <typename T>
547 __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
548
549 template <>
550 __host__ __device__ constexpr void
551 operator()<float>(float& y, const float& x0, const float& x1) const
552 {
553 float a = x0 + x1;
554 float b = a + float{3};
555 float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
556 y = c;
557 };
558
559 template <>
560 __host__ __device__ constexpr void
561 operator()<double>(double& y, const double& x0, const double& x1) const
562 {
563 double a = x0 + x1;
564 double b = a + 3.0;
565 double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
566 y = c;
567 };
568
569 template <>
570 __host__ __device__ constexpr void
571 operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
572 {
573 float a = x0 + x1;
574 float b = a + 3.0f;
575 float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
576 y = c;
577 };
578};
579
580// E = FastGelu(C + D)
582{
583 static constexpr const char* name = "AddFastGelu";
584
585 template <typename E, typename C, typename D>
586 __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
587
588 template <>
589 __host__ __device__ constexpr void
590 operator()<float, float, float>(float& e, const float& c, const float& d) const
591 {
592 const float x = c + d;
593
594 FastGelu{}.template operator()<float, float>(e, x);
595 }
596
597 template <>
598 __host__ __device__ constexpr void
599 operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
600 {
601 const half_t x = c + d;
602
603 ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
604 }
605
606 template <>
607 __host__ __device__ constexpr void
608 operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
609 {
610 const float x0_f = c + d;
611
612 float x1_f = 0;
613
614 ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
615 x0_f);
616
617 e = type_convert<half_t>(x1_f);
618 }
619
620 template <>
621 __host__ __device__ constexpr void
622 operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
623 {
624 const float x0_f = type_convert<float>(c) + type_convert<float>(d);
625
626 float x1_f = 0;
627
628 FastGelu{}.template operator()<float, float>(x1_f, x0_f);
629
630 e = type_convert<bhalf_t>(x1_f);
631 }
632
633 template <>
634 __host__ __device__ constexpr void
635 operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
636 {
637 const float x0_f = c + type_convert<float>(d);
638
639 float x1_f = 0;
640
641 FastGelu{}.template operator()<float, float>(x1_f, x0_f);
642
643 e = type_convert<bhalf_t>(x1_f);
644 }
645};
646
647// E = MultiplyFastGelu(C + D)
649{
650 static constexpr const char* name = "MultiplyFastGelu";
651
652 template <typename E, typename C, typename D>
653 __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
654
655 template <>
656 __host__ __device__ constexpr void
657 operator()<float, float, float>(float& e, const float& c, const float& d) const
658 {
659 const float x = c * d;
660
661 FastGelu{}.template operator()<float, float>(e, x);
662 }
663
664 template <>
665 __host__ __device__ constexpr void
666 operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
667 {
668 const half_t x = c * d;
669
670 ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
671 }
672
673 template <>
674 __host__ __device__ constexpr void
675 operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
676 {
677 const float x0_f = c * d;
678
679 float x1_f = 0;
680
681 ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
682 x0_f);
683
684 e = type_convert<half_t>(x1_f);
685 }
686
687 template <>
688 __host__ __device__ constexpr void
689 operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
690 {
691 const float x0_f = type_convert<float>(c) * type_convert<float>(d);
692
693 float x1_f = 0;
694
695 FastGelu{}.template operator()<float, float>(x1_f, x0_f);
696
697 e = type_convert<bhalf_t>(x1_f);
698 }
699
700 template <>
701 __host__ __device__ constexpr void
702 operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
703 {
704 const float x0_f = c * type_convert<float>(d);
705
706 float x1_f = 0;
707
708 FastGelu{}.template operator()<float, float>(x1_f, x0_f);
709
710 e = type_convert<bhalf_t>(x1_f);
711 }
712};
713
714// E = Silu(C + D)
716{
717 static constexpr const char* name = "AddSilu";
718
719 template <typename E, typename C, typename D>
720 __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
721
722 template <>
723 __host__ __device__ constexpr void
724 operator()<float, float, float>(float& e, const float& c, const float& d) const
725 {
726 const float x = c + d;
727
728 Silu{}.template operator()<float>(e, x);
729 }
730
731 template <>
732 __host__ __device__ constexpr void
733 operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
734 {
735 const half_t x = c + d;
736
737 Silu{}.template operator()<half_t>(e, x);
738 }
739
740 template <>
741 __host__ __device__ constexpr void
742 operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
743 {
744 const float x0_f = c + d;
745
746 float x1_f = 0;
747
748 Silu{}.template operator()<float>(x1_f, x0_f);
749
750 e = type_convert<half_t>(x1_f);
751 }
752
753 template <>
754 __host__ __device__ constexpr void
755 operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
756 {
757 const float x0_f = c + type_convert<float>(d);
758
759 float x1_f = 0;
760
761 Silu{}.template operator()<float>(x1_f, x0_f);
762
763 e = type_convert<bhalf_t>(x1_f);
764 }
765};
766
768{
769 static constexpr const char* name = "ConvScaleAdd";
770
771 __host__ __device__ ConvScaleAdd(float scale_in = 1.f,
772 float scale_wei = 1.f,
773 float scale_out = 1.f)
774 : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
775 {
776 }
777
778 template <typename E, typename C, typename D>
779 __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
780
781 template <>
782 __host__ __device__ void
783 operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
784 {
785 float x;
786 Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
788 };
789
793};
794
795} // namespace element_wise
796} // namespace tensor_operation
797} // namespace ck
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition binary_element_wise_operation.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
static constexpr const char * name
Definition binary_element_wise_operation.hpp:370
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition binary_element_wise_operation.hpp:372
const float ceil_
Definition binary_element_wise_operation.hpp:456
const float floor_
Definition binary_element_wise_operation.hpp:455
Definition binary_element_wise_operation.hpp:582
__host__ __device__ constexpr void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition binary_element_wise_operation.hpp:583
Definition binary_element_wise_operation.hpp:543
__host__ __device__ constexpr void operator()(T &y, const T &x0, const T &x1) const
static constexpr const char * name
Definition binary_element_wise_operation.hpp:544
Definition binary_element_wise_operation.hpp:14
static constexpr const char * name
Definition binary_element_wise_operation.hpp:15
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:460
static constexpr const char * name
Definition binary_element_wise_operation.hpp:461
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:716
__host__ __device__ constexpr void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition binary_element_wise_operation.hpp:717
__host__ __device__ constexpr void operator()(Y &, const X0 &, const X1 &) const
Bilinear(float alpha=1.f, float beta=1.f)
Definition binary_element_wise_operation.hpp:296
static constexpr const char * name
Definition binary_element_wise_operation.hpp:294
float beta_
Definition binary_element_wise_operation.hpp:365
float alpha_
Definition binary_element_wise_operation.hpp:364
float scale_in_
Definition binary_element_wise_operation.hpp:790
float scale_wei_
Definition binary_element_wise_operation.hpp:791
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition binary_element_wise_operation.hpp:771
float scale_out_
Definition binary_element_wise_operation.hpp:792
static constexpr const char * name
Definition binary_element_wise_operation.hpp:769
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:924
Definition binary_element_wise_operation.hpp:98
static constexpr const char * name
Definition binary_element_wise_operation.hpp:99
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:102
Definition binary_element_wise_operation.hpp:111
static constexpr const char * name
Definition binary_element_wise_operation.hpp:112
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:115
Definition binary_element_wise_operation.hpp:649
static constexpr const char * name
Definition binary_element_wise_operation.hpp:650
__host__ __device__ constexpr void operator()(E &e, const C &c, const D &d) const
Definition binary_element_wise_operation.hpp:124
static constexpr const char * name
Definition binary_element_wise_operation.hpp:125
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:224
float scale_
Definition binary_element_wise_operation.hpp:243
__host__ __device__ ScaleAdd(float scale=1.f)
Definition binary_element_wise_operation.hpp:221
static constexpr const char * name
Definition binary_element_wise_operation.hpp:219
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
Definition binary_element_wise_operation.hpp:247
static constexpr const char * name
Definition binary_element_wise_operation.hpp:248
__host__ __device__ constexpr void operator()(T &y, const T &x0, const T &x1) const