FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ > Struct Template Reference

FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ > Struct Template Reference
ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ > Struct Template Reference

#include <fmha_bwd_kernel.hpp>

Classes

struct  t2s
struct  t2s< float >
struct  t2s< ck_tile::fp16_t >
struct  t2s< ck_tile::bf16_t >
struct  FmhaBwdOGradDotOCommonKargs
struct  FmhaBwdOGradDotOBatchModeKargs
struct  FmhaBwdOGradDotOGroupModeKargs

Public Types

using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>
using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>
using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>
using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, const void *seqlen_q_ptr, const void *cu_seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
static CK_TILE_DEVICE constexpr auto GetTileIndex ()
static CK_TILE_HOST dim3 BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu
static constexpr ck_tile::index_t kM0 = kBlockSize
static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim
static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode
static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ
static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV

Member Typedef Documentation

◆ DDataType

template<typename FmhaBwdOGradDotO_>
using ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>

◆ FmhaBwdOGradDotO

template<typename FmhaBwdOGradDotO_>
using ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>

◆ Kargs

template<typename FmhaBwdOGradDotO_>
using ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::Kargs
Initial value:
std::conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>

◆ ODataType

template<typename FmhaBwdOGradDotO_>
using ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>

◆ OGradDataType

template<typename FmhaBwdOGradDotO_>
using ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>

Member Function Documentation

◆ BlockSize()

template<typename FmhaBwdOGradDotO_>
CK_TILE_HOST dim3 ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::BlockSize ( )
inline static

◆ GetName()

template<typename FmhaBwdOGradDotO_>
CK_TILE_HOST std::string ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaBwdOGradDotO_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaBwdOGradDotO_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::GetTileIndex ( )
inline static constexpr

◆ GridSize()

template<typename FmhaBwdOGradDotO_>
CK_TILE_HOST constexpr auto ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::GridSize ( ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_ )
inline static constexpr

◆ MakeKargs() [1/2]

template<typename FmhaBwdOGradDotO_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::MakeKargs ( const void * o_ptr,
const void * do_ptr,
void * d_ptr,
float p_undrop,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_d )
inline static constexpr

◆ MakeKargs() [2/2]

template<typename FmhaBwdOGradDotO_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::MakeKargs ( const void * o_ptr,
const void * do_ptr,
void * d_ptr,
float p_undrop,
const void * seqstart_q_ptr,
const void * seqlen_q_ptr,
const void * cu_seqlen_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d )
inline static constexpr

◆ operator()()

template<typename FmhaBwdOGradDotO_>
CK_TILE_DEVICE void ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::operator() ( Kargs kargs) const
inline

Member Data Documentation

◆ kBlockPerCu

template<typename FmhaBwdOGradDotO_>
ck_tile::index_t ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaBwdOGradDotO_>
ck_tile::index_t ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kBlockSize = FmhaBwdOGradDotO::kBlockSize
static constexpr

◆ kIsGroupMode

template<typename FmhaBwdOGradDotO_>
bool ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode
static constexpr

◆ kM0

template<typename FmhaBwdOGradDotO_>
ck_tile::index_t ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kM0 = kBlockSize
static constexpr

◆ kPadHeadDimV

template<typename FmhaBwdOGradDotO_>
bool ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV
static constexpr

◆ kPadSeqLenQ

template<typename FmhaBwdOGradDotO_>
bool ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ
static constexpr

◆ kVHeaddim

template<typename FmhaBwdOGradDotO_>
ck_tile::index_t ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::kVHeaddim = FmhaBwdOGradDotO::kVHeaddim
static constexpr

The documentation for this struct was generated from the following file: