FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <fmha_fwd_splitkv_combine_kernel.hpp>

Classes

struct  t2s
struct  t2s< float >
struct  t2s< ck_tile::fp16_t >
struct  t2s< ck_tile::bf16_t >
struct  t2s< ck_tile::fp8_t >
struct  t2s< ck_tile::bf8_t >
struct  EmptyKargs
struct  CommonKargs
struct  CommonLSEKargs
struct  Fp8StaticQuantKargs
struct  BatchModeKargs
struct  GroupModeKargs

Public Types

using FmhaPipeline = remove_cvref_t<FmhaPipeline_>
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>
using LSEDataType = remove_cvref_t<typename FmhaPipeline::LSEDataType>
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *lse_acc_ptr, const void *o_acc_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v)
static CK_TILE_DEVICE constexpr auto GetTileIndex (const Kargs &kargs)
static CK_TILE_HOST dim3 BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant

Member Typedef Documentation

◆ EpiloguePipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>

◆ FmhaPipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = remove_cvref_t<FmhaPipeline_>

◆ Kargs

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>

◆ LSEDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = remove_cvref_t<typename FmhaPipeline::LSEDataType>

◆ OaccDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>

◆ ODataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>

Member Function Documentation

◆ BlockSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST dim3 ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::BlockSize ( )
inline static

◆ GetName()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST std::string ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::GetTileIndex ( const Kargs & kargs)
inline static constexpr

◆ GridSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::GridSize ( ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v )
inline static constexpr

◆ MakeKargs() [1/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * lse_acc_ptr,
const void * o_acc_ptr,
void * lse_ptr,
void * o_ptr,
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
float scale_o,
ck_tile::index_t row_stride_o_acc,
ck_tile::index_t row_stride_o,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc )
inline static constexpr

◆ MakeKargs() [2/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * lse_acc_ptr,
const void * o_acc_ptr,
void * lse_ptr,
void * o_ptr,
ck_tile::index_t batch,
const void * seqstart_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
float scale_o,
ck_tile::index_t row_stride_o_acc,
ck_tile::index_t row_stride_o,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc )
inline static constexpr

◆ operator()()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE void ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::operator() ( Kargs kargs) const
inline

Member Data Documentation

◆ kBlockPerCu

template<typename FmhaPipeline_, typename EpiloguePipeline_>
index_t ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr

◆ kBlockPerCuInput

template<typename FmhaPipeline_, typename EpiloguePipeline_>
index_t ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaPipeline_, typename EpiloguePipeline_>
index_t ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize
static constexpr

◆ kDoFp8StaticQuant

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
static constexpr

◆ kIsGroupMode

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr

◆ kNumWarps

template<typename FmhaPipeline_, typename EpiloguePipeline_>
index_t ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kNumWarps = FmhaPipeline::kNumWarps
static constexpr

◆ kPadHeadDimV

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr

◆ kPadSeqLenQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr

◆ kStoreLSE

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVCombineKernel< FmhaPipeline_, EpiloguePipeline_ >::kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr

The documentation for this struct was generated from the following file: