Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ > Struct Template Reference

Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ > Struct Template Reference
ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ > Struct Template Reference

#include <rmsnorm2d_fwd_pipeline_two_pass.hpp>

Public Types

using Problem = ck_tile::remove_cvref_t<Problem_>
using Policy = ck_tile::remove_cvref_t<Policy_>
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>
using XResidualDataType = XDataType
using YResidualDataType = XDataType

Public Member Functions

template<typename XWindow, typename XResidualWindow, typename GammaWindow, typename YWindow, typename YResidualWindow, typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, typename UnquantYWindow, typename Epilogue>
CK_TILE_DEVICE auto operator() (const XWindow &x_window_, const XResidualWindow &x_residual_window_, const GammaWindow &gamma_window_, YWindow &y_window, const YResidualWindow &y_residual_window_, InvRmsWindow &inv_rms_window, const SmoothScaleWindow &, YScaleWindow &, UnquantYWindow &, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const

Static Public Member Functions

static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize ()

Static Public Attributes

static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync
static constexpr bool kPadM = false
static constexpr bool kPadN = Problem::Traits::kPadN
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant
static constexpr const char * name

Member Typedef Documentation

◆ ComputeDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>

◆ GammaDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>

◆ InvRmsDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>

◆ Policy

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::Policy = ck_tile::remove_cvref_t<Policy_>

◆ Problem

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::Problem = ck_tile::remove_cvref_t<Problem_>

◆ XDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>

◆ XResidualDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::XResidualDataType = XDataType

◆ YDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>

◆ YResidualDataType

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
using ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::YResidualDataType = XDataType

Member Function Documentation

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
CK_TILE_HOST_DEVICE constexpr index_t ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ operator()()

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
template<typename XWindow, typename XResidualWindow, typename GammaWindow, typename YWindow, typename YResidualWindow, typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, typename UnquantYWindow, typename Epilogue>
CK_TILE_DEVICE auto ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::operator() ( const XWindow & x_window_,
const XResidualWindow & x_residual_window_,
const GammaWindow & gamma_window_,
YWindow & y_window,
const YResidualWindow & y_residual_window_,
InvRmsWindow & inv_rms_window,
const SmoothScaleWindow & ,
YScaleWindow & ,
UnquantYWindow & ,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void * smem,
Epilogue  ) const
inline

Member Data Documentation

◆ kFusedAdd

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
auto ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kFusedAdd = Problem::Traits::kFusedAdd
static constexpr

◆ kFusedQuant

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
auto ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kFusedQuant = Problem::Traits::kFusedQuant
static constexpr

◆ kHasGamma

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
bool ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>
static constexpr

◆ kNeedCrossWarpSync

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
bool ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kNeedCrossWarpSync = Problem::kNeedCrossWarpSync
static constexpr

◆ kPadM

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
bool ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kPadM = false
static constexpr

◆ kPadN

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
bool ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kPadN = Problem::Traits::kPadN
static constexpr

◆ kSaveInvRms

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
bool ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::kSaveInvRms = Problem::Traits::kSaveInvRms
static constexpr

◆ name

template<typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
const char* ck_tile::Rmsnorm2dFwdPipelineTwoPass< Problem_, Policy_ >::name
static constexpr
Initial value:
= []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_tp";
else
return "wpr_tp";
}()
static constexpr bool kNeedCrossWarpSync
Definition add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp:30

The documentation for this struct was generated from the following file: