15template <
typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
37 static_assert(
kQLoadOnce == Policy::QLoadOnce);
49 static_assert(
kSubQKHeaddim <= 256,
"hdim bigger than 256 is not suitable for this pipeline!");
57 static constexpr auto BiasEnum = Problem::BiasEnum;
58 static constexpr bool kStoreLSE = Problem::kStoreLSE;
72 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
74 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
76 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
77 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
79 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
83 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
85 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
88 if constexpr(Problem::kBlockPerCu != -1)
89 return Problem::kBlockPerCu;
118 static constexpr const char*
name =
"qr";
120 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
127 template <
typename QDramBlockWindowTmp,
128 typename KDramBlockWindowTmp,
129 typename VDramBlockWindowTmp,
130 typename BiasDramBlockWindowTmp,
131 typename RandValDramBlockWindowTmp,
132 typename LSEDramBlockWindowTmp,
133 typename QElementFunction,
134 typename KElementFunction,
135 typename VElementFunction,
136 typename BiasElementFunction,
137 typename LSEElementFunction,
138 typename SAccElementFunction,
139 typename PComputeElementFunction,
140 typename OAccElementFunction,
141 typename PositionEncoding,
142 typename AttentionVariantParams,
143 typename BlockIndices>
145 operator()(
const QDramBlockWindowTmp& q_dram_block_window_tmp,
146 const QElementFunction& q_element_func,
147 const KDramBlockWindowTmp& k_dram_block_window_tmp,
148 const KElementFunction& k_element_func,
149 const VDramBlockWindowTmp& v_dram_block_window_tmp,
150 const VElementFunction& v_element_func,
151 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
152 const BiasElementFunction& bias_element_func,
153 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
154 LSEDramBlockWindowTmp& lse_dram_window_tmp,
155 const LSEElementFunction& lse_element_func,
156 const SAccElementFunction& s_acc_element_func,
157 const PComputeElementFunction& p_compute_element_func,
158 const OAccElementFunction& o_acc_element_func,
160 PositionEncoding position_encoding,
163 const AttentionVariantParams& variant_params,
164 const BlockIndices& block_indices,
169 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
170 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
171 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
174 static_assert(
kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
175 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
176 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
177 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
178 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
179 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
180 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}],
185 static_cast<char*
>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
187 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
194 Policy::template MakeVLdsBlockDescriptor<Problem>());
196 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
199 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
200 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
202 auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
203 q_dram_block_window_tmp.get_window_lengths(),
204 q_dram_block_window_tmp.get_window_origin(),
205 Policy::template MakeQRegTileDistribution<Problem>());
209 using SaccBlockTileType =
decltype(gemm_0.MakeCBlockTile());
210 auto s_acc = SaccBlockTileType{};
213 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
214 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
222 using OaccBlockTileType =
decltype(gemm_1.MakeCBlockTile());
225 auto o_acc = OaccBlockTileType{};
226 auto m = MLBlockTileType{};
227 auto l = MLBlockTileType{};
233 const auto q_origin = q_dram_window.get_window_origin();
234 const auto [seqlen_k_start, seqlen_k_end] =
242 if(num_total_loop <= 0)
260 auto k_dram_block_window =
262 k_dram_block_window_tmp.get_window_lengths(),
263 {seqlen_k_start, 0});
265 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
266 auto bias_dram_window =
268 bias_dram_block_window_tmp.get_window_lengths(),
269 {bias_origin.at(number<0>{}), seqlen_k_start},
270 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
272 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
273 randval_dram_block_window_tmp, seqlen_k_start);
277 v_dram_block_window_tmp.get_window_lengths(),
279 Policy::template MakeVDramTileDistribution<Problem>());
289 auto schedule_gemm0 = [] {
291 constexpr auto WarpGemmConfig =
292 BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
293 using WarpGemm0 =
remove_cvref_t<
decltype(WarpGemmConfig.template at<0>())>;
294 constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
295 constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
296 constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
297 constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
298 constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
299 constexpr index_t NumMfmaInsts = (
kM0 / WarpGemm0M) * (
kN0 / WarpGemm0N) *
300 (
kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
303 static_assert(NumMfmaInsts % 8 == 0);
304 static_for<0, NumMfmaInsts / 8, 1>{}([&](
auto) {
305 __builtin_amdgcn_sched_group_barrier(
DS_READ, 2, 0);
306 __builtin_amdgcn_sched_group_barrier(
MFMA, 2, 0);
307 __builtin_amdgcn_sched_group_barrier(
DS_READ, 1, 0);
308 __builtin_amdgcn_sched_group_barrier(
MFMA, 2, 0);
309 __builtin_amdgcn_sched_group_barrier(
DS_READ, 1, 0);
310 __builtin_amdgcn_sched_group_barrier(
MFMA, 4, 0);
315 static_assert(2 <= k0_loops);
316 static_assert(1 <= k1_loops);
321 k_dram_block_window.get_bottom_tensor_view(),
322 k_dram_block_window.get_window_lengths(),
323 k_dram_block_window.get_window_origin(),
324 Policy::template MakeKDramTileDistribution<Problem>());
327 auto k_block_tile =
load_tile(k_dram_window);
337 __builtin_amdgcn_sched_barrier(
340 const auto bias_tile =
load_tile(bias_dram_window);
343 __builtin_amdgcn_sched_barrier(
347 if constexpr(k0_loops > 2)
349 static_for<0, k0_loops - 2, 1>{}([&](
auto i_k0) {
353 sequence<0, i_k0 * kK0>{},
354 sequence<
kM0, (i_k0 + 1) *
kK0>{}),
367 const auto v_prefetch =
load_tile(v_dram_window);
372 sequence<0, (k0_loops - 2) *
kK0>{},
373 sequence<
kM0, (k0_loops - 1) *
kK0>{}),
383 sequence<0, (k0_loops - 1) *
kK0>{},
384 sequence<kM0, k0_loops * kK0>{}),
395 [&](
auto& x,
const auto& y) {
396#if !CK_TILE_FMHA_FWD_FAST_EXP2
408 const auto k_origin = k_dram_block_window.get_window_origin();
409 constexpr auto s_spans =
decltype(s_acc)::get_distributed_spans();
414 s_acc.get_tile_distribution(),
make_tuple(idx0, idx1));
418 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
420 s_acc(i_j_idx) *= scale_s;
421 position_encoding.update(s_acc(i_j_idx), row, col);
430 auto apply_logits_transform =
431 [&variant, &variant_params, &block_indices](
auto& x) {
432 x = variant.LogitsTransform(variant_params,
433 variant.QueryTransform(variant_params, x),
434 block_indices.batch_idx,
435 block_indices.qo_head_idx,
436 block_indices.kv_head_idx);
438#if !CK_TILE_FMHA_FWD_FAST_EXP2
446#if !CK_TILE_FMHA_FWD_FAST_EXP2
454 const auto k_origin = k_dram_block_window.get_window_origin();
455 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(
number<0>{}),
459 if(need_perpixel_check)
465 return !variant.LogitsMask(variant_params,
466 block_indices.batch_idx,
469 block_indices.qo_head_idx,
470 block_indices.kv_head_idx);
483 const auto m_old = m;
485 [](
auto& e0,
auto e1,
auto e2) { e0 =
max(e1, e2); }, m, m_old, m_local);
488 s.get_tile_distribution());
506 constexpr auto p_spans =
decltype(p_compute)::get_distributed_spans();
509#if CK_TILE_FMHA_FWD_FAST_EXP2
510 auto row_max = scale_s * get_validated_m(m[i_idx]);
513 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
514#if CK_TILE_FMHA_FWD_FAST_EXP2
518 p_compute(i_j_idx) =
exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
524 p_compute(i_j_idx) =
exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
528 p_compute(i_j_idx) =
exp2(scale_s * s[i_j_idx] - row_max);
532 p_compute(i_j_idx) =
exp(s[i_j_idx] - get_validated_m(m[i_idx]));
542 constexpr auto o_spans =
decltype(o_acc)::get_distributed_spans();
545#if CK_TILE_FMHA_FWD_FAST_EXP2
546 const auto tmp = [&]() {
550 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
557 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
561 auto row_max = scale_s * get_validated_m(m[i_idx]);
562 return exp2(scale_s * m_old[i_idx] - row_max);
567 const auto tmp =
exp(m_old[i_idx] - get_validated_m(m[i_idx]));
569 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
571 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
575 o_acc(i_j_idx) *= tmp;
584 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
585 smem_ptr, seqlen_k_start + i_total_loops *
kN0, p_compute, randval_dram_window);
589 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
592 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
609 if constexpr(k1_loops > 1)
611 static_for<0, k1_loops - 1, 1>{}([&](
auto i_k1) {
616 p, sequence<0, i_k1 * kK1>{}, sequence<
kM0, (i_k1 + 1) *
kK1>{}),
619 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
622 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
646 }
while(++i_total_loops < num_total_loop);
653 constexpr auto lse_spans =
decltype(lse)::get_distributed_spans();
656#if CK_TILE_FMHA_FWD_FAST_EXP2
660 lse(i_idx) = m_[i_idx] /
C_LOG2E +
log(l_[i_idx]);
666 lse(i_idx) = m_[i_idx] /
C_LOG2E +
log(l_[i_idx]);
670 lse(i_idx) = m_[i_idx] * scale_s /
C_LOG2E +
log(l_[i_idx]);
674 lse(i_idx) = m_[i_idx] +
log(l_[i_idx]);
682 constexpr auto o_spans =
decltype(o_acc)::get_distributed_spans();
686 const auto tmp = [&]() {
687 if constexpr(FmhaMask::IsMasking)
689 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
695 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
696 o_acc(i_j_idx) *= tmp;
705 template <
typename QDramBlockWindowTmp,
706 typename KDramBlockWindowTmp,
707 typename VDramBlockWindowTmp,
708 typename BiasDramBlockWindowTmp,
709 typename RandValDramBlockWindowTmp,
710 typename LSEDramBlockWindowTmp,
711 typename PositionEncoding,
712 typename AttentionVariantParams,
713 typename BlockIndices>
715 operator()(
const QDramBlockWindowTmp& q_dram_block_window_tmp,
716 const KDramBlockWindowTmp& k_dram_block_window_tmp,
717 const VDramBlockWindowTmp& v_dram_block_window_tmp,
718 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
719 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
720 LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
722 PositionEncoding position_encoding,
725 const AttentionVariantParams& variant_params,
726 const BlockIndices& block_indices,
732 k_dram_block_window_tmp,
734 v_dram_block_window_tmp,
736 bias_dram_block_window_tmp,
738 randval_dram_block_window_tmp,
739 lse_dram_block_window_tmp,
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned int uint32_t
Definition stdint.h:126
Definition block_fmha_pipeline_qr_ks_vs.hpp:17
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:28
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:24
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs.hpp:44
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs.hpp:715
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:29
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs.hpp:53
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs.hpp:47
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:25
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs.hpp:19
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs.hpp:73
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs.hpp:36
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs.hpp:31
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:23
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:20
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:27
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs.hpp:71
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs.hpp:59
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs.hpp:52
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs.hpp:42
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs.hpp:57
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs.hpp:75
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs.hpp:87
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:26
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:22
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs.hpp:43
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs.hpp:58
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs.hpp:82
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs.hpp:46
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs.hpp:51
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs.hpp:122
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs.hpp:41
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs.hpp:118
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs.hpp:45
static constexpr uint32_t MFMA
Definition block_fmha_pipeline_qr_ks_vs.hpp:62
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs.hpp:18
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs.hpp:55
static constexpr uint32_t DS_READ
Definition block_fmha_pipeline_qr_ks_vs.hpp:61
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs.hpp:32
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs.hpp:84
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs.hpp:35
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs.hpp:39
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:30
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:21
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_pipeline_qr_ks_vs.hpp:120
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs.hpp:145
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs.hpp:34
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs.hpp:56
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs.hpp:54
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469