device_gemm_multiple_d_ab_scale.hpp Source File

device_gemm_multiple_d_ab_scale.hpp Source File

Composable Kernel: device_gemm_multiple_d_ab_scale.hpp Source File
device_gemm_multiple_d_ab_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14// GEMM:
15// input : A[M, K], B[K, N],
16// input : D0[M, N], D1[M, N], ...
17// output : E[M, N]
18// C = a_op(A) * b_op(B)
19// E = cde_op(C, D0, D1, ...)
20// Assume:
21// D0, D1, ... and E have the same layout
// Abstract device-operation interface for a GEMM with multiple D tensors.
// Besides A, B, the Ds and E, it also takes separate, type-erased scale buffers
// for A and B (p_a_scale / p_b_scale in MakeArgumentPointer).
// Every member function is pure virtual, so a derived class must override all
// of them before it can be instantiated.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, ELayout - layout tag types, one per tensor
//     (the overview comment above states that D0, D1, ... and E share a layout).
//   ADataType, BDataType, DsDataType, EDataType - element types. DsDataType must
//     provide a constexpr static Size(), which is used to compute NumDTensor.
//   AScaleType, BScaleType - presumably the element types of the A / B scale
//     buffers. They are not referenced anywhere in this interface itself
//     (TODO confirm against implementations).
//   ScaleBlockM, ScaleBlockN, ScaleBlockK - compile-time scale-block extents.
//     NOTE(review): presumably one A-scale value per ScaleBlockM x ScaleBlockK
//     tile of A and one B-scale value per ScaleBlockK x ScaleBlockN tile of B.
//     This is inferred from the names only; confirm against implementations.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation -
//     functor types, passed by value to MakeArgumentPointer.
//
// NOTE(review): std::unique_ptr is declared in <memory>, which is not among the
// visible includes. Presumably it arrives through the include dropped from
// original line 8; verify against the real header.
22template <typename ALayout,
23 typename BLayout,
24 typename DsLayout,
25 typename ELayout,
26 typename ADataType,
27 typename AScaleType,
28 typename BDataType,
29 typename BScaleType,
30 typename DsDataType,
31 typename EDataType,
32 index_t ScaleBlockM,
33 index_t ScaleBlockN,
34 index_t ScaleBlockK,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CDEElementwiseOperation>
// NOTE(review): the documentation extraction dropped original line 38, which
// declares this struct's name (and, presumably, its base class). As shown here
// the listing does not compile; restore that line from the real header.
39{
// Number of D tensors. It sizes both the p_ds pointer array and the StrideDs
// array taken by MakeArgumentPointer.
40 static constexpr index_t NumDTensor = DsDataType::Size();
41
// Builds a type-erased argument object for one GEMM problem; the caller owns the
// returned std::unique_ptr.
//   p_a, p_b, p_ds, p_e  - data pointers for A, B, each D, and E (type-erased;
//                          presumably device memory - confirm).
//   M, N, K              - problem extents, with A[M, K], B[K, N], E[M, N] as in
//                          the overview comment.
//   StrideA, StrideB, StrideDs, StrideE - per-tensor strides (presumably leading
//                          dimensions for the chosen layouts - confirm).
//   p_a_scale, p_b_scale - type-erased pointers to the A / B scale data.
//   a_element_op, b_element_op, cde_element_op - elementwise functors, by value.
42 virtual std::unique_ptr<BaseArgument>
43 MakeArgumentPointer(const void* p_a,
44 const void* p_b,
45 std::array<const void*, NumDTensor> p_ds,
46 void* p_e,
47 const ck::index_t M,
48 const ck::index_t N,
49 const ck::index_t K,
50 const ck::index_t StrideA,
51 const ck::index_t StrideB,
52 const std::array<ck::index_t, NumDTensor> StrideDs,
53 const ck::index_t StrideE,
54 const void* p_a_scale,
55 const void* p_b_scale,
56 AElementwiseOperation a_element_op,
57 BElementwiseOperation b_element_op,
58 CDEElementwiseOperation cde_element_op) = 0;
59
// Returns a new invoker object, owned by the caller. Presumably it is used to
// run an argument built by MakeArgumentPointer - confirm in BaseInvoker.
60 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
61
// Sets the K-batch count on an existing argument object (presumably the split-K
// factor - confirm). Declared const, so it does not modify this operator object
// itself, only *arg.
62 virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
63};
64
// Second abstract device-operation interface in this header. Its template
// parameters and MakeArgumentPointer signature match the preceding ab-scale
// interface. The differences are:
//   - it declares GetPreShuffleParameters();
//   - it does not declare SetKBatch.
// Every member function is pure virtual, so a derived class must override all
// of them before it can be instantiated.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, ELayout - layout tag types, one per tensor.
//   ADataType, BDataType, DsDataType, EDataType - element types. DsDataType must
//     provide a constexpr static Size(), which is used to compute NumDTensor.
//   AScaleType, BScaleType - presumably the element types of the A / B scale
//     buffers. They are not referenced in this interface itself (TODO confirm).
//   ScaleBlockM, ScaleBlockN, ScaleBlockK - compile-time scale-block extents.
//     Their exact meaning is not shown here; presumably the tile size covered
//     by one scale value - confirm against implementations.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation -
//     functor types, passed by value to MakeArgumentPointer.
65template <typename ALayout,
66 typename BLayout,
67 typename DsLayout,
68 typename ELayout,
69 typename ADataType,
70 typename AScaleType,
71 typename BDataType,
72 typename BScaleType,
73 typename DsDataType,
74 typename EDataType,
75 index_t ScaleBlockM,
76 index_t ScaleBlockN,
77 index_t ScaleBlockK,
78 typename AElementwiseOperation,
79 typename BElementwiseOperation,
80 typename CDEElementwiseOperation>
// NOTE(review): the documentation extraction dropped original line 81, which
// declares this struct's name (and, presumably, its base class). As shown here
// the listing does not compile; restore that line from the real header.
82{
// Number of D tensors. It sizes both the p_ds pointer array and the StrideDs
// array taken by MakeArgumentPointer.
83 static constexpr index_t NumDTensor = DsDataType::Size();
84
// Builds a type-erased argument object for one GEMM problem; the caller owns the
// returned std::unique_ptr.
//   p_a, p_b, p_ds, p_e  - data pointers for A, B, each D, and E (type-erased;
//                          presumably device memory - confirm).
//   M, N, K              - problem extents (A[M, K], B[K, N], E[M, N]).
//   StrideA, StrideB, StrideDs, StrideE - per-tensor strides (presumably leading
//                          dimensions for the chosen layouts - confirm).
//   p_a_scale, p_b_scale - type-erased pointers to the A / B scale data.
//   a_element_op, b_element_op, cde_element_op - elementwise functors, by value.
85 virtual std::unique_ptr<BaseArgument>
86 MakeArgumentPointer(const void* p_a,
87 const void* p_b,
88 std::array<const void*, NumDTensor> p_ds,
89 void* p_e,
90 const ck::index_t M,
91 const ck::index_t N,
92 const ck::index_t K,
93 const ck::index_t StrideA,
94 const ck::index_t StrideB,
95 const std::array<ck::index_t, NumDTensor> StrideDs,
96 const ck::index_t StrideE,
97 const void* p_a_scale,
98 const void* p_b_scale,
99 AElementwiseOperation a_element_op,
100 BElementwiseOperation b_element_op,
101 CDEElementwiseOperation cde_element_op) = 0;
102
// Returns a new invoker object, owned by the caller. Presumably it is used to
// run an argument built by MakeArgumentPointer - confirm in BaseInvoker.
103 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
104
// Returns an implementation-defined int. Judging by the name, it describes the
// pre-shuffle configuration that callers must apply to their input data
// (presumably to B) - confirm against implementations.
// NOTE(review): unlike SetKBatch in the preceding interface, this query is not
// declared const, although it takes no arguments.
105 virtual int GetPreShuffleParameters() = 0;
106};
107
108} // namespace device
109} // namespace tensor_operation
110} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_base.hpp:197
Definition device_gemm_multiple_d_ab_scale.hpp:39
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_ab_scale.hpp:40
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual void SetKBatch(BaseArgument *arg, int KBatch) const =0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const ck::index_t M, const ck::index_t N, const ck::index_t K, const ck::index_t StrideA, const ck::index_t StrideB, const std::array< ck::index_t, NumDTensor > StrideDs, const ck::index_t StrideE, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
Definition device_gemm_multiple_d_ab_scale.hpp:82
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_ab_scale.hpp:83
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const ck::index_t M, const ck::index_t N, const ck::index_t K, const ck::index_t StrideA, const ck::index_t StrideB, const std::array< ck::index_t, NumDTensor > StrideDs, const ck::index_t StrideE, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0