Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 105 additions & 6 deletions include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,118 @@ struct CShuffleEpilogue
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
constexpr auto DataTypeSize = sizeof(ODataType);
constexpr index_t VectorLen = GetVectorSizeC();
constexpr index_t banks = get_n_lds_banks();

constexpr index_t BytesPerBank = 4;

// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
constexpr index_t MLdsLayerRequired =
banks * BytesPerBank / NPerIterationShuffle / DataTypeSize;
constexpr auto MLdsLayer = max(1, MLdsLayerRequired);

constexpr index_t BaseStrideElems = NPerIterationShuffle * MLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
// calculate how many elements to pad to avoid bank conflict
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif

constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle / MLdsLayer>{},
number<NPerIterationShuffle / VectorLen * MLdsLayer>{},
number<VectorLen>{}),
make_tuple(number<NPerIterationShuffle * MLdsLayer + PaddingAmount>{},
number<VectorLen>{},
number<1>{}),
number<VectorLen>{},
number<1>{});

constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MPerIterationShuffle / MLdsLayer>{}),
make_unmerge_transform(make_tuple(
number<MLdsLayer>{}, number<NPerIterationShuffle / VectorLen>{})),
make_pass_through_transform(number<VectorLen>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<MPerIterationShuffle / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<NPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

return lds_block_desc;
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
constexpr index_t NLdsLayerRequired =
get_n_lds_banks() * BytesPerBank / MPerIterationShuffle / DataTypeSize;
constexpr auto NLdsLayer = max(1, NLdsLayerRequired);

constexpr index_t BaseStrideElems = MPerIterationShuffle * NLdsLayer;

static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");

#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif

constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NPerIterationShuffle / NLdsLayer>{},
number<MPerIterationShuffle / VectorLen * NLdsLayer>{},
number<VectorLen>{}),
make_tuple(number<MPerIterationShuffle * NLdsLayer + PaddingAmount>{},
number<VectorLen>{},
number<1>{}),
number<VectorLen>{},
number<1>{});

constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NPerIterationShuffle / NLdsLayer>{}),
make_unmerge_transform(make_tuple(
number<NLdsLayer>{}, number<MPerIterationShuffle / VectorLen>{})),
make_pass_through_transform(number<VectorLen>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<NPerIterationShuffle / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<MPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

return lds_block_desc;
}
else
{
Expand Down