CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference

CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference
ck_tile::CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference

#include <cshuffle_epilogue.hpp>

Classes

struct  EmptyScale
struct  ScaleDataType
struct  ScaleDataType< T, std::void_t< typename T::DataType > >

Public Types

using Problem = remove_cvref_t<Problem_>
using AsDataType = remove_cvref_t<typename Problem::AsDataType>
using BsDataType = remove_cvref_t<typename Problem::BsDataType>
using AccDataType = remove_cvref_t<typename Problem::AccDataType>
using ODataType = remove_cvref_t<typename Problem::ODataType>
using DsDataType = remove_cvref_t<typename Problem::DsDataType>
using DsLayout = remove_cvref_t<typename Problem::DsLayout>
using AsDataTypeTuple
using BsDataTypeTuple
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>
using ATypeToUse
using BTypeToUse
using ELayout = remove_cvref_t<typename Problem::ELayout>
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>
using WG
using CWarpDstr = typename WG::CWarpDstr
using CWarpTensor = typename WG::CWarpTensor
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding
using SFC

Public Member Functions

CK_TILE_DEVICE CShuffleEpilogue (CDElementwise elfunc=CDElementwise{})
template<index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void scale_tile (LdsTile &lds_tile, ScaleM &scale_m_window, ScaleN &scale_n_window)
template<index_t iAccess, typename OAccTile, typename LdsTile>
CK_TILE_DEVICE void slice_acc_tile (const OAccTile &o_acc_tile, LdsTile &lds_tile)
template<typename LdsTile, typename InLdsWindow>
CK_TILE_DEVICE void cast_lds_tile (LdsTile &lds_tile, InLdsWindow &in_lds_window)
template<typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void apply_d_tensors (DramWindows &d_dram_windows, COutTensor &c_out_tensor)
template<typename OutDramWindow, typename COutTensor>
CK_TILE_DEVICE void store_to_dram (OutDramWindow &out_dram_window, const COutTensor &c_out_tensor)
template<index_t iAccess, typename OutDramWindow, typename DDramWindows>
CK_TILE_DEVICE void move_windows (OutDramWindow &out_dram_window, DDramWindows &d_dram_windows)
 Move both the output and D tensor windows for the next access.
template<typename ODramWindow, typename OAccTile, typename DsDramWindows, typename ScaleM = EmptyScale, typename ScaleN = EmptyScale, int EnablePermuateN_ = TiledMMAPermuteN, std::enable_if_t< EnablePermuateN_, int > = 0>
CK_TILE_DEVICE auto operator() (ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *, const ScaleM &scale_m={}, const ScaleN &scale_n={})
template<typename ODramWindow, typename OAccTile, typename DsDramWindows, typename ScaleM = EmptyScale, typename ScaleN = EmptyScale, int EnablePermuateN_ = TiledMMAPermuteN, std::enable_if_t<!EnablePermuateN_, int > = 0>
CK_TILE_DEVICE auto operator() (ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem, const ScaleM &scale_m={}, const ScaleN &scale_n={})

Static Public Member Functions

static CK_TILE_HOST const std::string GetName ()
static CK_TILE_HOST_DEVICE constexpr index_t GetVectorSizeC ()
 Get the vector store size for C tensor.
template<index_t I>
static CK_TILE_HOST_DEVICE constexpr index_t GetVectorSizeD (number< I > index)
 Get the vector store size for Di tensor.
template<typename Problem>
static CK_TILE_HOST_DEVICE constexpr auto MakeLdsBlockDescriptor ()
static CK_TILE_DEVICE constexpr auto MakeLdsDistributionEncode ()
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize ()

Public Attributes

CDElementwise elfunc_

Static Public Attributes

static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation
static constexpr index_t kBlockSize = Problem::kBlockSize
static constexpr index_t kMPerBlock = Problem::kMPerBlock
static constexpr index_t kNPerBlock = Problem::kNPerBlock
static constexpr index_t MWave = Problem::MWave
static constexpr index_t NWave = Problem::NWave
static constexpr index_t MPerXdl = Problem::MPerXdl
static constexpr index_t NPerXdl = Problem::NPerXdl
static constexpr index_t KPerXdl = Problem::KPerXdl
static constexpr index_t isCTransposed = Problem::isCTransposed
static constexpr bool FixedVectorSize = Problem::FixedVectorSize
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp
static constexpr index_t VectorSizeC = Problem::VectorSizeC
static constexpr index_t MPerIteration = MPerXdl * MWave
static constexpr index_t NPerIteration = NPerXdl * NWave
static constexpr index_t NumDTensor = Problem::NumDTensor
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave)
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave)
static constexpr auto shuffle_tile_tuple
 Shuffle tile configuration parameters.
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple)
static constexpr index_t NumNXdlPerWavePerShuffle
static constexpr auto MNPerIterationShuffle
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle)
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle)

Member Typedef Documentation

◆ AccDataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::AccDataType = remove_cvref_t<typename Problem::AccDataType>

◆ ADataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>

◆ AsDataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::AsDataType = remove_cvref_t<typename Problem::AsDataType>

◆ AsDataTypeTuple

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::AsDataTypeTuple
Initial value:
std::conditional_t<ADataTypeIsTuple,
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
static constexpr bool ADataTypeIsTuple
Definition cshuffle_epilogue.hpp:81

◆ ATypeToUse

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ATypeToUse
Initial value:
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition cshuffle_epilogue.hpp:92
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition cshuffle_epilogue.hpp:93

◆ BDataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>

◆ BsDataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BsDataType = remove_cvref_t<typename Problem::BsDataType>

◆ BsDataTypeTuple

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BsDataTypeTuple
Initial value:
std::conditional_t<BDataTypeIsTuple,
static constexpr bool BDataTypeIsTuple
Definition cshuffle_epilogue.hpp:82

◆ BTypeToUse

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BTypeToUse
Initial value:
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>

◆ CDElementwise

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CDElementwise = remove_cvref_t<typename Problem::CDElementwise>

◆ CWarpDstr

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CWarpDstr = typename WG::CWarpDstr

◆ CWarpDstrEncoding

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CWarpDstrEncoding = typename WG::CWarpDstrEncoding

◆ CWarpTensor

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CWarpTensor = typename WG::CWarpTensor

◆ DsDataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::DsDataType = remove_cvref_t<typename Problem::DsDataType>

◆ DsLayout

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::DsLayout = remove_cvref_t<typename Problem::DsLayout>

◆ ELayout

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ELayout = remove_cvref_t<typename Problem::ELayout>

◆ ODataType

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType>

◆ Problem

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_>

◆ SFC

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::SFC

◆ WG

template<typename Problem_, typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::WG
Initial value:
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
static constexpr index_t MPerXdl
Definition cshuffle_epilogue.hpp:108
static constexpr index_t isCTransposed
Definition cshuffle_epilogue.hpp:111
std::conditional_t< std::is_same_v< ADataType, pk_int4_t >, BDataType, ADataType > ATypeToUse
Definition cshuffle_epilogue.hpp:95
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition cshuffle_epilogue.hpp:76
static constexpr index_t KPerXdl
Definition cshuffle_epilogue.hpp:110
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition cshuffle_epilogue.hpp:98
static constexpr index_t NPerXdl
Definition cshuffle_epilogue.hpp:109

Constructor & Destructor Documentation

◆ CShuffleEpilogue()

template<typename Problem_, typename Policy_ = void>
CK_TILE_DEVICE ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CShuffleEpilogue ( CDElementwise elfunc = CDElementwise{})
inline

Member Function Documentation

◆ apply_d_tensors()

template<typename Problem_, typename Policy_ = void>
template<typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void ck_tile::CShuffleEpilogue< Problem_, Policy_ >::apply_d_tensors ( DramWindows & d_dram_windows,
COutTensor & c_out_tensor )
inline

◆ cast_lds_tile()

template<typename Problem_, typename Policy_ = void>
template<typename LdsTile, typename InLdsWindow>
CK_TILE_DEVICE void ck_tile::CShuffleEpilogue< Problem_, Policy_ >::cast_lds_tile ( LdsTile & lds_tile,
InLdsWindow & in_lds_window )
inline

◆ GetName()

template<typename Problem_, typename Policy_ = void>
CK_TILE_HOST const std::string ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetName ( )
inline static nodiscard

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = void>
CK_TILE_HOST_DEVICE constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ GetVectorSizeC()

template<typename Problem_, typename Policy_ = void>
CK_TILE_HOST_DEVICE constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetVectorSizeC ( )
inline static constexpr

Get the vector store size for C tensor.

Note
The vector store size for the output C tensor depends on multiple factors, such as its data layout and the warp GEMM C transposition. In general, it is the number of consecutive elements along the contiguous C dimension held by a single thread.
Returns
The vector store size for C tensor.

◆ GetVectorSizeD()

template<typename Problem_, typename Policy_ = void>
template<index_t I>
CK_TILE_HOST_DEVICE constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetVectorSizeD ( number< I > index)
inline static constexpr

Get the vector store size for Di tensor.

Returns
The vector store size for Di tensor.

◆ MakeLdsBlockDescriptor()

template<typename Problem_, typename Policy_ = void>
template<typename Problem>
CK_TILE_HOST_DEVICE constexpr auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MakeLdsBlockDescriptor ( )
inline static constexpr

◆ MakeLdsDistributionEncode()

template<typename Problem_, typename Policy_ = void>
CK_TILE_DEVICE constexpr auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MakeLdsDistributionEncode ( )
inline static constexpr

◆ move_windows()

template<typename Problem_, typename Policy_ = void>
template<index_t iAccess, typename OutDramWindow, typename DDramWindows>
CK_TILE_DEVICE void ck_tile::CShuffleEpilogue< Problem_, Policy_ >::move_windows ( OutDramWindow & out_dram_window,
DDramWindows & d_dram_windows )
inline

Move both the output and D tensor windows for the next access.

◆ operator()() [1/2]

template<typename Problem_, typename Policy_ = void>
template<typename ODramWindow, typename OAccTile, typename DsDramWindows, typename ScaleM = EmptyScale, typename ScaleN = EmptyScale, int EnablePermuateN_ = TiledMMAPermuteN, std::enable_if_t< EnablePermuateN_, int > = 0>
CK_TILE_DEVICE auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::operator() ( ODramWindow & out_dram_window,
const OAccTile & o_acc_tile,
const DsDramWindows & ds_dram_windows,
void * ,
const ScaleM & scale_m = {},
const ScaleN & scale_n = {} )
inline

◆ operator()() [2/2]

template<typename Problem_, typename Policy_ = void>
template<typename ODramWindow, typename OAccTile, typename DsDramWindows, typename ScaleM = EmptyScale, typename ScaleN = EmptyScale, int EnablePermuateN_ = TiledMMAPermuteN, std::enable_if_t<!EnablePermuateN_, int > = 0>
CK_TILE_DEVICE auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::operator() ( ODramWindow & out_dram_window,
const OAccTile & o_acc_tile,
const DsDramWindows & ds_dram_windows,
void * p_smem,
const ScaleM & scale_m = {},
const ScaleN & scale_n = {} )
inline

◆ scale_tile()

template<typename Problem_, typename Policy_ = void>
template<index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void ck_tile::CShuffleEpilogue< Problem_, Policy_ >::scale_tile ( LdsTile & lds_tile,
ScaleM & scale_m_window,
ScaleN & scale_n_window )
inline

◆ slice_acc_tile()

template<typename Problem_, typename Policy_ = void>
template<index_t iAccess, typename OAccTile, typename LdsTile>
CK_TILE_DEVICE void ck_tile::CShuffleEpilogue< Problem_, Policy_ >::slice_acc_tile ( const OAccTile & o_acc_tile,
LdsTile & lds_tile )
inline

◆ store_to_dram()

template<typename Problem_, typename Policy_ = void>
template<typename OutDramWindow, typename COutTensor>
CK_TILE_DEVICE void ck_tile::CShuffleEpilogue< Problem_, Policy_ >::store_to_dram ( OutDramWindow & out_dram_window,
const COutTensor & c_out_tensor )
inline

Member Data Documentation

◆ ADataTypeIsTuple

template<typename Problem_, typename Policy_ = void>
bool ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value
static constexpr

◆ BDataTypeIsTuple

template<typename Problem_, typename Policy_ = void>
bool ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value
static constexpr

◆ BlockedXDLN_PerWarp

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp
static constexpr

◆ elfunc_

template<typename Problem_, typename Policy_ = void>
CDElementwise ck_tile::CShuffleEpilogue< Problem_, Policy_ >::elfunc_

◆ FixedVectorSize

template<typename Problem_, typename Policy_ = void>
bool ck_tile::CShuffleEpilogue< Problem_, Policy_ >::FixedVectorSize = Problem::FixedVectorSize
static constexpr

◆ isCTransposed

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::isCTransposed = Problem::isCTransposed
static constexpr

◆ kBlockSize

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize
static constexpr

◆ kMPerBlock

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::kMPerBlock = Problem::kMPerBlock
static constexpr

◆ kNPerBlock

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::kNPerBlock = Problem::kNPerBlock
static constexpr

◆ KPerXdl

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::KPerXdl = Problem::KPerXdl
static constexpr

◆ MemoryOperation

template<typename Problem_, typename Policy_ = void>
memory_operation_enum ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MemoryOperation = Problem::MemoryOperation
static constexpr

◆ MNPerIterationShuffle

template<typename Problem_, typename Policy_ = void>
auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MNPerIterationShuffle
static constexpr
Initial value:
= [] {
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}()
int32_t index_t
Definition integer.hpp:9
static constexpr index_t kNPerBlock
Definition cshuffle_epilogue.hpp:105
static constexpr index_t MWave
Definition cshuffle_epilogue.hpp:106
static constexpr index_t NumMXdlPerWavePerShuffle
Definition cshuffle_epilogue.hpp:236
static constexpr index_t NumNXdlPerWavePerShuffle
Definition cshuffle_epilogue.hpp:237
static constexpr index_t NWave
Definition cshuffle_epilogue.hpp:107
static constexpr index_t kMPerBlock
Definition cshuffle_epilogue.hpp:104

◆ MPerIteration

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MPerIteration = MPerXdl * MWave
static constexpr

◆ MPerIterationShuffle

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MPerIterationShuffle = std::get<0>(MNPerIterationShuffle)
static constexpr

◆ MPerXdl

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MPerXdl = Problem::MPerXdl
static constexpr

◆ MRepeat

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MRepeat = kMPerBlock / (MPerXdl * MWave)
static constexpr

◆ MWave

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MWave = Problem::MWave
static constexpr

◆ NPerIteration

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NPerIteration = NPerXdl * NWave
static constexpr

◆ NPerIterationShuffle

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NPerIterationShuffle = std::get<1>(MNPerIterationShuffle)
static constexpr

◆ NPerXdl

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NPerXdl = Problem::NPerXdl
static constexpr

◆ NRepeat

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NRepeat = kNPerBlock / (NPerXdl * NWave)
static constexpr

◆ NumDTensor

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NumDTensor = Problem::NumDTensor
static constexpr

◆ NumMXdlPerWavePerShuffle

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple)
static constexpr

◆ NumNXdlPerWavePerShuffle

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NumNXdlPerWavePerShuffle
static constexpr
Initial value:
=
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
static constexpr index_t BlockedXDLN_PerWarp
Definition cshuffle_epilogue.hpp:114
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition cshuffle_epilogue.hpp:209

◆ NWave

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NWave = Problem::NWave
static constexpr

◆ shuffle_tile_tuple

template<typename Problem_, typename Policy_ = void>
auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::shuffle_tile_tuple
static constexpr
Initial value:
= [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
}
}
}()
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
static CK_TILE_HOST_DEVICE constexpr index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition cshuffle_epilogue.hpp:151

Shuffle tile configuration parameters.

These parameters control the number of XDL tiles processed per wave in each shuffle iteration:

  • NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
  • NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave

◆ TiledMMAPermuteN

template<typename Problem_, typename Policy_ = void>
bool ck_tile::CShuffleEpilogue< Problem_, Policy_ >::TiledMMAPermuteN = Problem::TiledMMAPermuteN
static constexpr

◆ VectorSizeC

template<typename Problem_, typename Policy_ = void>
index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::VectorSizeC = Problem::VectorSizeC
static constexpr

The documentation for this struct was generated from the following file: