device_grouped_conv_fwd.hpp Source File

device_grouped_conv_fwd.hpp Source File

Composable Kernel: device_grouped_conv_fwd.hpp Source File
device_grouped_conv_fwd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14// Convolution Forward:
15// input : input image A[G, N, C, Hi, Wi],
16// input : weight B[G, K, C, Y, X],
17// output : output image E[G, N, K, Ho, Wo]
18// C = a_op(A) * b_op(B)
19// E = cde_op(C, D0, D1, ...)
// Abstract interface for a grouped forward-convolution device operation: it
// declares only two pure-virtual factory methods and holds no data members.
//
// Template parameters:
//   NDimSpatial              - number of spatial dimensions; fixes the sizes of
//                              the std::array arguments below.
//   InLayout/WeiLayout/OutLayout
//                            - layout tag types; not referenced by the
//                              declarations visible in this block.
//   InDataType/WeiDataType/OutDataType
//                            - element type tags; not referenced by the
//                              declarations visible in this block (the tensor
//                              pointers are passed as untyped void*).
//   In/Wei/OutElementwiseOperation
//                            - functor types of the three element-wise
//                              operations passed to MakeArgumentPointer.
//
// NOTE(review): listing line 30 (the line that names this type and its base
// classes) is absent from this extracted copy; the numbering jumps from 29 to
// 31. Restore it from the original header before compiling.
// NOTE(review): the comment preceding this template mentions D0, D1, ... and a
// cde_op, but this interface takes no D-tensor arguments and its operations are
// named in/wei/out_element_op - confirm the intended wording.
20template <index_t NDimSpatial,
21 typename InLayout,
22 typename WeiLayout,
23 typename OutLayout,
24 typename InDataType,
25 typename WeiDataType,
26 typename OutDataType,
27 typename InElementwiseOperation,
28 typename WeiElementwiseOperation,
29 typename OutElementwiseOperation>
31{
// Pure virtual: returns an owning pointer to a BaseArgument built from the
// tensor pointers, tensor descriptors and convolution parameters below.
//   p_in / p_wei / p_out  - untyped pointers to the input image, the weight and
//                           the output image; only p_out is non-const.
//   *_lengths / *_strides - per-tensor arrays of NDimSpatial + 3 entries. The
//                           names suggest the order [G, N, C, spatial...] for
//                           the input, [G, K, C, spatial...] for the weight and
//                           [G, N, K, spatial...] for the output - TODO confirm.
//   conv_filter_strides, conv_filter_dilations, input_left_pads,
//   input_right_pads      - one entry per spatial dimension.
//   in/wei/out_element_op - element-wise operation objects, taken by const
//                           reference.
32 virtual std::unique_ptr<BaseArgument>
33 MakeArgumentPointer(const void* p_in, // input image
34 const void* p_wei, // weight
35 void* p_out, // output image
36 const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
37 const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
38 const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
39 const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_strides,
40 const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
41 const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
42 const std::array<index_t, NDimSpatial>& conv_filter_strides,
43 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
44 const std::array<index_t, NDimSpatial>& input_left_pads,
45 const std::array<index_t, NDimSpatial>& input_right_pads,
46 const InElementwiseOperation& in_element_op,
47 const WeiElementwiseOperation& wei_element_op,
48 const OutElementwiseOperation& out_element_op) = 0;
49
// Pure virtual: returns an owning pointer to a BaseInvoker. Presumably the
// invoker runs the operation on an argument created above - verify against
// the implementing classes.
50 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
51};
52
53} // namespace device
54} // namespace tensor_operation
55} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_grouped_conv_fwd.hpp:31
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, const void *p_wei, void *p_out, const std::array< index_t, NDimSpatial+3 > &in_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &in_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &wei_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &wei_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &out_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &out_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const InElementwiseOperation &in_element_op, const WeiElementwiseOperation &wei_element_op, const OutElementwiseOperation &out_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0