topk_softmax_kernel.hpp Source File

topk_softmax_kernel.hpp Source File

Composable Kernel: topk_softmax_kernel.hpp Source File
topk_softmax_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10#include <string>
11#include <type_traits>
12
13namespace ck_tile {
14
// TopkSoftmaxHostArgs: host-side argument pack for TopkSoftmaxKernel.
// The buffer pointers are type-erased; TopkSoftmaxKernel::operator() reinterprets
// p_input as Problem::InputType, p_output as Problem::WeightType and p_indices as
// Problem::IndexType.
// NOTE: gaps in the embedded line numbers mark source lines this rendering omits:
// the struct header (line 15) and, per the member index, num_rows (20),
// num_experts (21) and topk (22).
16{
17 const void* p_input; // viewed by the kernel as num_rows x num_experts, row stride stride_input
18 void* p_output; // viewed by the kernel as num_rows x topk, row stride stride_output
19 void* p_indices; // same view shape/stride as p_output
23 index_t stride_input; // row stride for input, at least experts
24 index_t stride_output; // row stride for output/indices, at least topk
25};
26
// TopkSoftmaxKernel: kernel wrapper for top-k softmax. It only sets up the tile
// windows over the input/output/indices buffers; the per-row computation is
// delegated to the Pipeline_ functor called at the end of operator().
// GridSize() (host) computes the launch grid, BlockSize() returns Problem::BlockSize,
// and MakeKargs() converts TopkSoftmaxHostArgs (Hargs) into TopkSoftmaxKargs (Kargs).
// NOTE: gaps in the embedded line numbers mark source lines this rendering omits
// (e.g. 28, 30-31: the struct header and the Pipeline/Problem aliases listed in the
// member index).
27template <typename Pipeline_>
29{
32
33 using InputType = typename Problem::InputType;
34 using WeightType = typename Problem::WeightType;
35 using IndexType = typename Problem::IndexType;
36
37 static constexpr index_t kBlockSize = Problem::BlockSize;
38
// Kernel-side argument pack (TopkSoftmaxKargs, aliased as Kargs); it carries the
// same members as TopkSoftmaxHostArgs. Its header line and members 44-46
// (num_rows, num_experts, topk per the member index) are omitted by this rendering.
40 {
41 const void* p_input;
42 void* p_output;
43 void* p_indices;
47 index_t stride_input; // row stride for input, at least experts
48 index_t stride_output; // row stride for output/indices, at least topk
49 };
50
53
// Launch grid size.
//  - LaunchType > 0: (number of compute units on the current HIP device) * LaunchType.
//    This does not depend on h.num_rows, so the pipeline must cover any remaining
//    rows itself (presumably by looping over rows) -- confirm against Pipeline_.
//  - otherwise: ceil(num_rows / RowsPerWarp) warps, packed WarpsPerBlock per block.
//    NOTE(review): this matches operator()'s blockIdx.x * RowsPerBlock only if
//    RowsPerBlock == RowsPerWarp * WarpsPerBlock -- confirm in Problem.
54 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
55 {
56 if constexpr(Problem::LaunchType > 0)
57 {
58 int num_cu = [&]() {
59 hipDeviceProp_t dev_prop;
60 hipDevice_t dev;
61 HIP_CHECK_ERROR(hipGetDevice(&dev));
62 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
63 return dev_prop.multiProcessorCount;
64 }();
65 return dim3(num_cu * Problem::LaunchType);
66 }
67 else
68 {
69 const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
70 const int num_blocks =
71 (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
72 return dim3(num_blocks);
73 }
74 }
75
// Copies the host arguments field by field into a Kargs.
// NOTE: source lines 83 and 85-86 are omitted by this rendering; presumably they
// copy num_experts, stride_input and stride_output, which operator() reads --
// confirm against the header.
76 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
77 {
78 Kargs k;
79 k.p_input = h.p_input;
80 k.p_output = h.p_output;
81 k.p_indices = h.p_indices;
82 k.num_rows = h.num_rows;
84 k.topk = h.topk;
87 return k;
88 }
89
// Threads per block (Problem::BlockSize, same value as kBlockSize).
90 CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
91
// Device entry point: operator()(Kargs kargs) const per the member index (its
// signature, source line 92, is omitted by this rendering).
// Block b handles the rows starting at b * RowsPerBlock. It builds tile windows for
// the rows [block_row_id, num_rows) and hands them to Pipeline_.
93 {
94 index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
95
// Exit when the block's first row is at or past the end: such a block owns no valid
// rows (num_rows_rem would be <= 0). The former `>` let a block with
// block_row_id == num_rows continue, and let block 0 continue when num_rows == 0,
// building zero-row views and calling the pipeline for nothing.
96 if(block_row_id >= kargs.num_rows)
97 return;
98
// Per-block element offsets and the number of rows left from this block onward.
// amd_wave_read_first_lane (ck_tile, amd_buffer_addressing.hpp) returns the first
// lane's value. These values are already identical across the wave, so this
// presumably only marks them as wave-uniform (scalar) -- confirm.
99 index_t block_os_inp = amd_wave_read_first_lane(block_row_id * kargs.stride_input);
100 index_t block_os_out = amd_wave_read_first_lane(block_row_id * kargs.stride_output);
101 index_t num_rows_rem = amd_wave_read_first_lane(kargs.num_rows - block_row_id);
102
// Input: num_rows_rem x num_experts view with row stride stride_input, padded only
// along the expert dimension (sequence<0, 1>); rows past the end rely on
// out-of-bounds handling instead of padding.
103 const auto input_window = [&]() {
104 const InputType* p_input =
105 reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
106
108 p_input,
109 make_tuple(num_rows_rem, kargs.num_experts),
110 make_tuple(kargs.stride_input, 1),
112 number<1>{});
113
114 auto view = pad_tensor_view(
115 tmp,
117 sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
118
119 return make_tile_window(
120 view,
122 {0, 0});
123 }();
124
// Output weights: num_rows_rem x topk view with row stride stride_output, no padding.
125 auto output_window = [&]() {
126 WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
128 p_output,
129 make_tuple(num_rows_rem, kargs.topk),
130 make_tuple(kargs.stride_output, 1),
132 number<1>{});
133 auto view =
134 pad_tensor_view(tmp,
136 sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
137 // 2. we loop over topk 1-1, no need padding
138 return make_tile_window(
140 }();
141
// Output indices: same shape, stride and padding as the output weights.
142 auto indices_window = [&]() {
143 IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
145 p_indices,
146 make_tuple(num_rows_rem, kargs.topk),
147 make_tuple(kargs.stride_output, 1),
149 number<1>{});
150 auto view =
151 pad_tensor_view(tmp,
153 sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
154 // 2. we loop over topk 1-1, no need padding
155 return make_tile_window(
157 }();
158
// Hand the windows plus the global sizes and this block's first row to the pipeline.
159 Pipeline{}(input_window,
160 output_window,
161 indices_window,
162 kargs.num_rows,
163 kargs.num_experts,
164 kargs.topk,
165 block_row_id);
166 }
167};
168} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition topk_softmax_kernel.hpp:16
index_t num_experts
Definition topk_softmax_kernel.hpp:21
index_t topk
Definition topk_softmax_kernel.hpp:22
index_t stride_output
Definition topk_softmax_kernel.hpp:24
const void * p_input
Definition topk_softmax_kernel.hpp:17
index_t num_rows
Definition topk_softmax_kernel.hpp:20
void * p_indices
Definition topk_softmax_kernel.hpp:19
index_t stride_input
Definition topk_softmax_kernel.hpp:23
void * p_output
Definition topk_softmax_kernel.hpp:18
Definition topk_softmax_kernel.hpp:40
const void * p_input
Definition topk_softmax_kernel.hpp:41
void * p_output
Definition topk_softmax_kernel.hpp:42
index_t stride_output
Definition topk_softmax_kernel.hpp:48
index_t stride_input
Definition topk_softmax_kernel.hpp:47
index_t num_rows
Definition topk_softmax_kernel.hpp:44
void * p_indices
Definition topk_softmax_kernel.hpp:43
index_t topk
Definition topk_softmax_kernel.hpp:46
index_t num_experts
Definition topk_softmax_kernel.hpp:45
Definition topk_softmax_kernel.hpp:29
TopkSoftmaxKargs Kargs
Definition topk_softmax_kernel.hpp:51
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition topk_softmax_kernel.hpp:54
static CK_TILE_HOST_DEVICE constexpr auto BlockSize()
Definition topk_softmax_kernel.hpp:90
TopkSoftmaxHostArgs Hargs
Definition topk_softmax_kernel.hpp:52
static constexpr index_t kBlockSize
Definition topk_softmax_kernel.hpp:37
remove_cvref_t< typename Pipeline::Problem > Problem
Definition topk_softmax_kernel.hpp:31
remove_cvref_t< Pipeline_ > Pipeline
Definition topk_softmax_kernel.hpp:30
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition topk_softmax_kernel.hpp:76
typename Problem::InputType InputType
Definition topk_softmax_kernel.hpp:33
typename Problem::WeightType WeightType
Definition topk_softmax_kernel.hpp:34
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition topk_softmax_kernel.hpp:92
typename Problem::IndexType IndexType
Definition topk_softmax_kernel.hpp:35
Definition tile/core/container/sequence.hpp:49