Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 114 additions & 51 deletions include/ck/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,48 @@ using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format

// scalar_type
template <typename TV>
struct scalar_type;
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same_v<T, double> || is_same_v<T, float> || is_same_v<T, half_t> ||
is_same_v<T, bhalf_t> || is_same_v<T, int32_t> || is_same_v<T, uint32_t> ||
is_same_v<T, int8_t> || is_same_v<T, uint8_t> || is_same_v<T, _BitInt(8)> ||
is_same_v<T, unsigned _BitInt(8)> || is_same_v<T, bool>;
}

/**
* @brief Wrapper for native vector type
* @tparam T The element type of the vector
* @tparam Rank The number of elements in the vector
*/
template <typename T, index_t Rank>
using NativeVectorT = T __attribute__((ext_vector_type(Rank)));

/**
* @brief Mapping of incoming type to local native vector storage type and vector size
* @tparam T Incoming data type
*/
template <typename T>
struct scalar_type
{
// Basic data type mapping to unsigned _BitInt of appropriate size
using type = unsigned _BitInt(8 * sizeof(T));
static constexpr index_t vector_size = 1;
};

/**
* @brief scalar_type trait override for NativeVectorT
* @tparam T The vector type
* @tparam Rank The number of elements in the vector
*/
template <typename T, index_t Rank>
struct scalar_type<NativeVectorT<T, Rank>>
{
using type = T;
static constexpr index_t vector_size = Rank;
};

struct f4x2_pk_t
{
Expand Down Expand Up @@ -74,6 +113,39 @@ struct f4x2_pk_t
}
};

// TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written
// in the following way:
// template<typename T, index_t Rank>
// struct scalar_type<T __attribute__((__vector_size__(sizeof(T) * Rank)))>
// {
// using type = T;
// static constexpr index_t vector_size = Rank;
// };
// The compiler errors out with "partial specialization is not allowed for this type",
// claiming that the Rank is not a deducible parameter. This might be a compiler bug.
// Note the above type is classified differently from the NativeVectorT<T, Rank> alias,
// even though they are functionally equivalent and are trivially constructibe from each other.
// This is unfortunate, but we have to work around it because some LLVM builtins for some
// operations (e.g., mma) may return the former type.
// For now we have to explicitly specialize for each vector size we need. These are used
// in f6_pk_t below.

/// @brief scalar_type trait override for uint32_t vector of size 3
template <>
struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 3)))>
{
using type = uint32_t;
static constexpr index_t vector_size = 3;
};

/// @brief scalar_type trait override for uint32_t vector of size 6
template <>
struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 6)))>
{
using type = uint32_t;
static constexpr index_t vector_size = 6;
};

template <typename BitType, index_t pk_size>
struct f6_pk_t
{
Expand All @@ -89,28 +161,48 @@ struct f6_pk_t
static constexpr index_t vector_size =
(packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units

using storage_type = element_type __attribute__((ext_vector_type(vector_size)));
using storage_type = NativeVectorT<element_type, vector_size>;
storage_type data_{storage_type(0)}; // packed data

using type = f6_pk_t<BitType, packed_size>;

/** This class may trivially constructed by the following vector type alias
* for example from a result of an mma operation. This is primarily for internal use.
* @note f6x16_pk_t and f6x32_pk_t storage types, may be trivially constructed from
* uint32_t vectors of size 3 and 6 respectively for example from mma operation results.
* Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a
* NativeVectorT<uint32_t, 6> is NOT the same as __attribute__((__vector_size__(6 *
* sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being
* functionally equivalent. This class may be trivially constructed from both, so we can steer
* the templated ctor below to only consider incoming vectors types other than our two storage
* types of interest.
*/
using storage_type_alias =
element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size)));

__host__ __device__ constexpr f6_pk_t() {}
__host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init}
{
// TODO: consider removing initialization similar to vector_type<T, 256>
}

// Initialize from a vector type with the same size as packed_size
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
// Initialize from a vector type with the same size as packed_size.
// Exclude storage_type and storage_type_alias because these are trivially constructible.
template <
typename T,
typename = enable_if_t<!is_same_v<T, storage_type> && !is_same_v<T, storage_type_alias> &&
scalar_type<T>::vector_size == packed_size>>
__host__ __device__ f6_pk_t(const T& v)
{
static_assert(scalar_type<T>::vector_size == packed_size,
"Input vector size must match packed_size.");
static_for<0, packed_size, 1>{}(
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
}

// Broadcast single initialization value to all packed elements
__host__ __device__ f6_pk_t(const int8_t v)
: f6_pk_t(static_cast<int8_t __attribute__((ext_vector_type(packed_size)))>(v))
: f6_pk_t(static_cast<NativeVectorT<int8_t, packed_size>>(v))
{
// TODO: consider removing initialization similar to vector_type<T, 256>
}
Expand Down Expand Up @@ -191,27 +283,6 @@ struct pk_i4_t
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};

inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}

// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value;
}

// scalar_type
template <typename TV>
struct scalar_type;

// is_scalar_type
template <typename TV>
struct is_scalar_type
Expand All @@ -224,14 +295,13 @@ template <typename X, typename Y>
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<Y>>::type>;

template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
template <>
struct scalar_type<bool>
{
using type = T;
static constexpr index_t vector_size = N;
using type = bool;
static constexpr index_t vector_size = 1;
};

//
template <>
struct scalar_type<double>
{
Expand Down Expand Up @@ -293,86 +363,79 @@ struct scalar_type<int4_t>
template <>
struct scalar_type<pk_i4_t>
{
using type = pk_i4_t;
using type = typename pk_i4_t::type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f8_fnuz_t>
{
using type = f8_fnuz_t::data_type;
using type = typename f8_fnuz_t::data_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_fnuz_t::data_type;
using type = typename bf8_fnuz_t::data_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
using type = typename f8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
using type = typename bf8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};

#ifndef CK_CODE_GEN_RTC
template <>
struct scalar_type<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
using type = typename e8m0_bexp_t::type;
static constexpr index_t vector_size = 1;
};
#endif

template <>
struct scalar_type<f4x2_pk_t>
{
using type = f4x2_pk_t::type;
using type = typename f4x2_pk_t::type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f6x32_pk_t>
{
using type = f6x32_pk_t::storage_type;
using type = typename f6x32_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf6x32_pk_t>
{
using type = bf6x32_pk_t::storage_type;
using type = typename bf6x32_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f6x16_pk_t>
{
using type = f6x16_pk_t::storage_type;
using type = typename f6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf6x16_pk_t>
{
using type = bf6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bool>
{
using type = bool;
using type = typename bf6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

Expand Down
Loading