diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8e6f875c39..ff0bb10d0c 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -34,9 +34,48 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format -// scalar_type -template -struct scalar_type; +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, +// native types: bool +template +inline constexpr bool is_native_type() +{ + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Wrapper for native vector type + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +using NativeVectorT = T __attribute__((ext_vector_type(Rank))); + +/** + * @brief Mapping of incoming type to local native vector storage type and vector size + * @tparam T Incoming data type + */ +template +struct scalar_type +{ + // Basic data type mapping to unsigned _BitInt of appropriate size + using type = unsigned _BitInt(8 * sizeof(T)); + static constexpr index_t vector_size = 1; +}; + +/** + * @brief scalar_type trait override for NativeVectorT + * @tparam T The vector type + * @tparam Rank The number of elements in the vector + */ +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = Rank; +}; struct f4x2_pk_t { @@ -74,6 +113,39 @@ struct f4x2_pk_t } }; +// TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written +// in the following way: +// template +// struct scalar_type +// { +// using type = T; +// static constexpr index_t vector_size = Rank; +// }; +// The compiler errors out with "partial specialization is not allowed for this type", +// claiming that the Rank is not a deducible parameter. 
This might be a compiler bug. +// Note the above type is classified differently from the NativeVectorT alias, +// even though they are functionally equivalent and are trivially constructible from each other. +// This is unfortunate, but we have to work around it because some LLVM builtins for some +// operations (e.g., mma) may return the former type. +// For now we have to explicitly specialize for each vector size we need. These are used +// in f6_pk_t below. + +/// @brief scalar_type trait override for uint32_t vector of size 3 +template <> +struct scalar_type +{ + using type = uint32_t; + static constexpr index_t vector_size = 3; +}; + +/// @brief scalar_type trait override for uint32_t vector of size 6 +template <> +struct scalar_type +{ + using type = uint32_t; + static constexpr index_t vector_size = 6; +}; + template struct f6_pk_t { @@ -89,28 +161,48 @@ struct f6_pk_t static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units - using storage_type = element_type __attribute__((ext_vector_type(vector_size))); + using storage_type = NativeVectorT; storage_type data_{storage_type(0)}; // packed data using type = f6_pk_t; + /** This class may be trivially constructed from the following vector type alias + * for example from a result of an mma operation. This is primarily for internal use. + * @note f6x16_pk_t and f6x32_pk_t storage types may be trivially constructed from + * uint32_t vectors of size 3 and 6 respectively for example from mma operation results. + * Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a + * NativeVectorT is NOT the same as __attribute__((__vector_size__(6 * + * sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being + * functionally equivalent. This class may be trivially constructed from both, so we can steer + * the templated ctor below to only consider incoming vector types other than our two storage + * types of interest. 
+ */ + using storage_type_alias = + element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size))); + __host__ __device__ constexpr f6_pk_t() {} __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init} { // TODO: consider removing initialization similar to vector_type } - // Initialize from a vector type with the same size as packed_size - template ::vector_size == packed_size>> + // Initialize from a vector type with the same size as packed_size. + // Exclude storage_type and storage_type_alias because these are trivially constructible. + template < + typename T, + typename = enable_if_t && !is_same_v && + scalar_type::vector_size == packed_size>> __host__ __device__ f6_pk_t(const T& v) { + static_assert(scalar_type::vector_size == packed_size, + "Input vector size must match packed_size."); static_for<0, packed_size, 1>{}( [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); } // Broadcast single initialization value to all packed elements __host__ __device__ f6_pk_t(const int8_t v) - : f6_pk_t(static_cast(v)) + : f6_pk_t(static_cast>(v)) { // TODO: consider removing initialization similar to vector_type } @@ -191,27 +283,6 @@ struct pk_i4_t __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} }; -inline constexpr auto next_pow2(uint32_t x) -{ - // Precondition: x > 1. - return x > 1u ? 
(1u << (32u - __builtin_clz(x - 1u))) : x; -} - -// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool -template -inline constexpr bool is_native_type() -{ - return is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same_v || is_same_v || is_same::value; -} - -// scalar_type -template -struct scalar_type; - // is_scalar_type template struct is_scalar_type @@ -224,14 +295,13 @@ template using has_same_scalar_type = is_same>::type, typename scalar_type>::type>; -template -struct scalar_type +template <> +struct scalar_type { - using type = T; - static constexpr index_t vector_size = N; + using type = bool; + static constexpr index_t vector_size = 1; }; -// template <> struct scalar_type { @@ -293,35 +363,35 @@ struct scalar_type template <> struct scalar_type { - using type = pk_i4_t; + using type = typename pk_i4_t::type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f8_fnuz_t::data_type; + using type = typename f8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t::data_type; + using type = typename bf8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f8_ocp_t::data_type; + using type = typename f8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_ocp_t::data_type; + using type = typename bf8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; @@ -329,7 +399,7 @@ struct scalar_type template <> struct scalar_type { - using type = e8m0_bexp_t::type; + using type = typename e8m0_bexp_t::type; static constexpr index_t vector_size = 1; }; #endif @@ -337,42 +407,35 @@ struct scalar_type template <> struct scalar_type { - using type = f4x2_pk_t::type; 
+ using type = typename f4x2_pk_t::type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f6x32_pk_t::storage_type; + using type = typename f6x32_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf6x32_pk_t::storage_type; + using type = typename bf6x32_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f6x16_pk_t::storage_type; + using type = typename f6x16_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf6x16_pk_t::storage_type; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = bool; + using type = typename bf6x16_pk_t::storage_type; static constexpr index_t vector_size = 1; }; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 204b199629..8ec4602743 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -3,199 +3,85 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { -// vector_type -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. 
The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type, N>; - -// vector_type_maker -// This is the right way to handle "vector of vectors": making a bigger vector instead -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = N; -}; - -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker, N0> -{ - using type = vector_type; -}; +__device__ int static err = 0; -template -using vector_type_maker_t = typename vector_type_maker::type; +template +struct non_native_vector_base; template -__host__ __device__ constexpr auto make_vector_type(Number) -{ - return typename vector_type_maker::type{}; -} - -template -struct vector_type()>> +struct non_native_vector_base< + T, + N, + ck::enable_if_t> { - using d1_t = T; - using type = d1_t; + using data_t = typename scalar_type::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = NativeVectorT; + using type = non_native_vector_base; - union + union alignas(math::next_power_of_two()) { - T d1_; - StaticallyIndexedArray d1x1_; + data_v dN; // storage vector; + StaticallyIndexedArray_v2 dxN; + StaticallyIndexedArray_v2 dTxN; + StaticallyIndexedArray_v2 dNx1; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } - - template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : 
non_native_vector_base(bit_cast(f)) { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; } -}; - -__device__ int static err = 0; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - template - __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) + if constexpr(N == 1) { - return data_.d2x1_; + return data_.dxN[Number<0>{}]; } else { - return err; + return data_.dxN; // XXX this should cause an error } } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] + __host__ __device__ constexpr operator T() const { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) + if constexpr(N == 1) { - return data_.d2x1_; + return data_.dTxN[Number<0>{}]; } else { - return err; + return data_.dTxN; // XXX this should cause an error } } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t 
__attribute__((ext_vector_type(3))); - - using type = d3_t; - - union - { - d3_t d3_; - StaticallyIndexedArray d1x3_; - StaticallyIndexedArray d2x1_; - StaticallyIndexedArray d3x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x3_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d2x1_; + return data_.dTxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d3x1_; + return data_.dNx1; } else { @@ -204,22 +90,22 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x3_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d2x1_; + return data_.dTxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d3x1_; + return data_.dNx1; } else { @@ -228,114 +114,80 @@ struct vector_type()>> } }; -template -struct vector_type()>> +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t 
__attribute__((ext_vector_type(4))); - - using type = d4_t; + using data_t = typename scalar_type::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); + using data_v = NativeVectorT; + using type = non_native_vector_base; - union + union alignas(math::next_power_of_two()) { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; + data_v dN; // storage vector; + StaticallyIndexedArray_v2 dxN; + StaticallyIndexedArray_v2 dTxN; + StaticallyIndexedArray_v2 dNx1; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] + // Broadcast single value to vector + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{} { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + // TODO: consider removing initialization similar to vector_type - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } + ck::static_for<0, N, 1>{}([&](auto i) { + data_.dxN(i) = a; // broadcast value to all elements + }); } - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + } - if 
constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) { - return data_.d4x1_; + return data_.dTxN[Number<0>{}]; } else { - return err; + return err; // XXX this should cause an error } } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d5_t __attribute__((ext_vector_type(5))); - - using type = d5_t; - - union - { - d5_t d5_; - StaticallyIndexedArray d1x5_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d5x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x5_; + return data_.dNx1; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d4x1_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d5x1_; + return data_.dTxN; } else { @@ -344,22 +196,22 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ auto& AsType() { - 
static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x5_; + return data_.dNx1; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d4x1_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d5x1_; + return data_.dTxN; } else { @@ -368,1781 +220,397 @@ struct vector_type()>> } }; -template -struct vector_type()>> +template +struct scalar_type>> +{ + using type = typename non_native_vector_base::data_t; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type>> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t __attribute__((ext_vector_type(3))); - typedef T d6_t __attribute__((ext_vector_type(6))); + using type = typename non_native_vector_base::element_t; + static constexpr index_t vector_size = N * non_native_vector_base::size_factor; +}; - using type = d6_t; +/** + * @brief Helper struct to determine the storage type for vector_type + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * @tparam Enable SFINAE helper + */ +template +struct vector_type_storage; + +/** + * @brief Vector storage type for native scalar types. + * @tparam T The element type of the vector + * @note For Rank = 1 and native types, the storage type is simply T itself (scalar) + */ +template +struct vector_type_storage()>> +{ + using type = T; +}; + +/** + * @brief Vector storage type for native vector types. + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * + * Assigns a native vector type based on the element type and rank. + * For boolean types, uses a C-style array `T[Rank]`, otherwise uses + * the `NativeVectorT` template specialization. 
+ * + * @note Special handling note: + * Sub-byte sizes such as bool have different sizes in ext_vector_type (via NativeVectorT) vs array + * types due to packing. Builtin vector types pack bool elements, while C++ arrays use 1 byte per + * bool as a standard (minimum write size = 1 byte). e.g., ext_vector_type(bool, 4) is packed as + * minimum 1 byte, while bool[4] is 4 bytes. vector_type::AsType, aliases with + * StaticallyIndexedArray_v2 which is C-style array under the hood, so we need to avoid using + * ext_vector_type with bool due to potential for data slicing errors. + */ +template +struct vector_type_storage() && (Rank > 1)>> +{ + using type = std::conditional_t, T[Rank], NativeVectorT>; +}; + +/** + * @brief Vector storage type for non-native vector types. + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * @note For non-native types, the storage type is non_native_vector_base + */ +template +struct vector_type_storage()>> +{ + using type = non_native_vector_base; +}; + +/** + * @brief Convenience wrapper for vector_type_storage + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +using vector_type_storage_t = typename vector_type_storage::type; + +/** + * @brief Trait to check whether one vector storage class is the same as another (e.g., same scalar, + * or same vector class). 
+ * @tparam Lhs The source storage type + * @tparam Rhs The comparator storage type + * + * Same storage classes are: + * - Same type + * - Same template vector types with matching base type (may have different ranks) + * - C-style arrays of same base type (may have different ranks) + */ +template +struct is_same_vector_storage_class : public false_type +{ +}; + +/** + * @brief Template native vector types of same base type with different ranks + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + */ +template +struct is_same_vector_storage_class, NativeVectorT> + : true_type +{ +}; + +/** + * @brief Template non-native vector types of same base type with different ranks + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + * @tparam Enable SFINAE helper + */ +template +struct is_same_vector_storage_class, + non_native_vector_base> : true_type +{ +}; + +/** + * @brief C-style arrays of same base type with different ranks + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + */ +template +struct is_same_vector_storage_class : true_type +{ +}; + +/** + * @brief Convenience evaluator for is_same_vector_storage_class + * @tparam Lhs The source storage type + * @tparam Rhs The comparator storage type + */ +template +static constexpr bool is_same_vector_storage_class_v = + is_same_vector_storage_class::value; + +// Fwd declaration +template +struct vector_type; - union +/** + * @brief Trait to extract element type and rank from vector_type and related types + * @tparam T The vector type + */ +template +struct vector_type_traits +{ + using element_type = T; + static constexpr index_t Rank = 1; +}; + +/** + * @brief Specialization of vector_type_traits for vector_type + * @tparam T The element type of the vector + * @tparam 
Rank_ The number of elements in the vector + */ +template +struct vector_type_traits> +{ + using element_type = T; + static constexpr index_t Rank = Rank_; +}; + +/** + * @brief Specialization of vector_type_traits for non_native_vector_base + * @tparam T The element type of the vector + * @tparam Rank_ The number of elements in the vector + */ +template +struct vector_type_traits> +{ + using element_type = T; + static constexpr index_t Rank = Rank_; +}; + +/** + * @brief Specialization of vector_type_traits for NativeVectorT + * @tparam T The element type of the vector + * @tparam Rank_ The number of elements in the vector + */ +template +struct vector_type_traits> +{ + using element_type = T; + static constexpr index_t Rank = Rank_; +}; + +/** + * @brief Vector type wrapper + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +struct vector_type +{ + /// @brief Internal storage type for vector_type. + using StorageT = vector_type_storage_t; + using type = StorageT; + StorageT data_; + + /// @brief Default constructor for vector_type + __host__ __device__ constexpr vector_type() : data_{} {} + + /// @brief Constructor for native vector initialization + __host__ __device__ constexpr vector_type(StorageT v) : data_{v} {} + + /** + * @brief Validates whether a type can be used in an AsType cast operation for vector_type + * class. + * + * This function checks if a given type X can be legally used as an alias to either reinterpret + * or slice (iterate) through the local storage type StorageT. The validation ensures type + * safety and structural compatibility between the source and target vector types. + * + * @tparam X The target type to validate for AsType casting. + * + * @return constexpr bool True if the type is valid for AsType casting, false otherwise + * + * @note Requirements for a valid AsType cast on vector_type: + * 1. The value type of X must match the storage value type (T) + * 2. 
X must be either: + * a) A scalar type (T) where RankX == 1, OR + * b) A vector class that matches the storage vector class (e.g., both are + * NativeVectorT or non_native_vector_base) where: + * - RankX is a power of 2, OR + * RankX == 3, OR + * RankX == Storage Rank + * - RankX must be <= Storage Rank + * @example + * auto srcVec = vector_type{}; // T = float, Rank = 8, native vector storage + * auto result = srcVec.AsType(); // Where datatype X could be: + * X = NativeVectorT; // OK: native vector T, RankX = 4 (power of 2) + * X = float; // OK: scalar T, RankX = 1 + * X = NativeVectorT; // ERROR: RankX not a power of 2, ==3, or ==Rank + * X = int; // ERROR: Invalid scalar cast, T != int + * X = float[4]; // ERROR: Invalid type, storage vector class doesn't + * // match (native vector != C-array) + */ + template + static constexpr bool is_as_type_cast_valid() + { + using TraitsX = vector_type_traits; + + // Checks storage classes match, with same base type (may have different ranks) + constexpr bool is_valid_cast = + is_same_vector_storage_class_v || // Matching vector storage + is_same_v; // Matching scalar type + + // Validate vector ranks + constexpr bool is_valid_rank = (math::is_power_of_two_integer(TraitsX::Rank) || + (TraitsX::Rank == 3) || (TraitsX::Rank == Rank)) && + (TraitsX::Rank <= Rank); + + return is_valid_cast && is_valid_rank; + } + + /** + * @brief Allows casting the vector_type to another type X via aliasing or slicing. + * Use cases are to expose the internal storage type, or to slice the vector into smaller + * vectors for iteration purposes. + * @tparam X The target type to validate for AsType casting. + * @returns a reference to the reinterpreted data as StaticallyIndexedArray_v2. + * Rigid control of allowable casts is enforced via static_assert to ensure type safety. + * See is_as_type_cast_valid() for requirements. 
+ */ + template + __host__ __device__ constexpr auto const& AsType() const [[clang::lifetimebound]] + { + // Make this a hard error if the datatype X is not a valid cast. + static_assert(is_as_type_cast_valid(), "Datatype X is not a valid AsType cast"); + + using TraitsX = vector_type_traits; + + // Calculate the new rank after slicing. + // Note: We might end up with incomplete quantization from slicing + // when Rank % TraitsX::Rank != 0, so take the floor division. + constexpr index_t newRank = Rank / TraitsX::Rank; + + // Determine the cast type: + // - Scalar T if slicing to scalar or vector size of 1, + // - X otherwise. + using CastT = conditional_t; + using ResultT = StaticallyIndexedArray_v2; + + // As a rule, the aliasing type should not be larger than the original type. + static_assert(sizeof(ResultT) <= sizeof(vector_type), + "Resulting aliasing cannot be larger than original type"); + + // Re-cast as vectorized type. + return *(bit_cast(this)); + } + + /** + * @brief Allows casting the vector_type to another type X via aliasing or slicing. + * Use cases are to expose the internal storage type, or to slice the vector into smaller + * vectors for iteration purposes. + * @tparam X The target type to validate for AsType casting. + * @returns a reference to the reinterpreted data as StaticallyIndexedArray_v2. + * Rigid control of allowable casts is enforced via static_assert to ensure type safety. + * See is_as_type_cast_valid() for requirements. + */ + template + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { - d6_t d6_; - StaticallyIndexedArray d1x6_; - StaticallyIndexedArray d2x3_; - StaticallyIndexedArray d3x2_; - StaticallyIndexedArray d6x1_; - } data_; + // Make this a hard error if the datatype X is not a valid cast. 
+ static_assert(is_as_type_cast_valid(), "Datatype X is not a valid AsType cast"); - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + using TraitsX = vector_type_traits; - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + // Calculate the new rank after slicing. + // Note: We might end up with incomplete quantization from slicing + // when Rank % TraitsX::Rank != 0, so take the floor division. + constexpr index_t newRank = Rank / TraitsX::Rank; - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + // Determine the cast type: + // - Scalar T if slicing to scalar or vector size of 1, + // - X otherwise. + using CastT = conditional_t; + using ResultT = StaticallyIndexedArray_v2; - if constexpr(is_same::value) - { - return data_.d1x6_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d3x2_; - } - else if constexpr(is_same::value) - { - return data_.d6x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + // As a rule, the aliasing type should not be larger than the original type. + static_assert(sizeof(ResultT) <= sizeof(vector_type), + "Resulting aliasing cannot be larger than original type"); - if constexpr(is_same::value) - { - return data_.d1x6_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d3x2_; - } - else if constexpr(is_same::value) - { - return data_.d6x1_; - } - else - { - return err; - } + // Re-cast as vectorized type. 
+ return *(bit_cast(this)); } }; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d7_t __attribute__((ext_vector_type(7))); - - using type = d7_t; - - union - { - d7_t d7_; - StaticallyIndexedArray d1x7_; - StaticallyIndexedArray d2x3_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d7x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. 
The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } -}; +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; -template -struct vector_type()>> +/** + * @brief scalar_type trait override for vector_type + * @tparam T The vector type + * @tparam N The number of elements in the vector + */ +template +struct scalar_type> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - - using type = d8_t; - - union - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() 
[[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } + using type = typename scalar_type::type; + static constexpr index_t vector_size = N; }; -template -struct vector_type()>> +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker { - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d13_t __attribute__((ext_vector_type(13))); - - using type = d13_t; - - union - { - d13_t d13_; - StaticallyIndexedArray d1x13_; - StaticallyIndexedArray d4x3_; - StaticallyIndexedArray d8x1_; - StaticallyIndexedArray d13x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if 
constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } + using type = vector_type; }; -template -struct vector_type()>> +template +struct vector_type_maker, N0> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - - using type = d16_t; - - union - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if 
constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } + using type = vector_type; }; -template -struct vector_type()>> +template +struct vector_type_maker, N0> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - - using type = d32_t; - - union - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_ = {d32_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - // __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - // __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + using type = vector_type; +}; - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - 
static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - - using type = d64_t; - - union - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if 
constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - - using type = d128_t; - - union - { - d128_t d128_; - StaticallyIndexedArray d1x128_; - StaticallyIndexedArray d2x64_; - StaticallyIndexedArray d4x32_; - StaticallyIndexedArray d8x16_; - StaticallyIndexedArray d16x8_; - StaticallyIndexedArray d32x4_; - StaticallyIndexedArray d64x2_; - StaticallyIndexedArray d128x1_; - } data_ = {d128_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - template - __host__ __device__ 
constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t 
__attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - typedef T d256_t __attribute__((ext_vector_type(256))); - - using type = d256_t; - - union - { - d256_t d256_; - StaticallyIndexedArray d1x256_; - StaticallyIndexedArray d2x128_; - StaticallyIndexedArray d4x64_; - StaticallyIndexedArray d8x32_; - StaticallyIndexedArray d16x16_; - StaticallyIndexedArray d32x8_; - StaticallyIndexedArray d64x4_; - StaticallyIndexedArray d128x2_; - StaticallyIndexedArray d256x1_; - } data_ = {d256_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please 
check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } -}; - -template -struct non_native_vector_base; - -template -struct nnvb_data_t_selector -{ - using type = unsigned _BitInt(8 * sizeof(T)); -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f8_ocp_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_ocp_t::data_type; -}; - -#ifndef CK_CODE_GEN_RTC -template <> -struct nnvb_data_t_selector -{ - using type = f8_fnuz_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_fnuz_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = e8m0_bexp_t::type; -}; -#endif - -template <> -struct nnvb_data_t_selector -{ - using type = f6x16_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f6x32_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x16_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x32_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = pk_i4_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f4x2_pk_t::type; -}; +template +using vector_type_maker_t = typename vector_type_maker::type; template -struct non_native_vector_base< - T, - N, - ck::enable_if_t> +__host__ __device__ constexpr auto 
make_vector_type(Number) { - using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T - static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - using data_v = data_t __attribute__((ext_vector_type(N))); - using type = non_native_vector_base; - - union alignas(next_pow2(N * sizeof(T))) - { - data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; - } data_; - - __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} - __host__ __device__ constexpr non_native_vector_base(T f) - : non_native_vector_base(bit_cast(f)) - { - } - __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - - __host__ __device__ constexpr operator data_v() const { return data_.dN; } - __host__ __device__ constexpr operator data_t() const - { - if constexpr(N == 1) - { - return data_.dxN[Number<0>{}]; - } - else - { - return data_.dxN; // XXX this should cause an error - } - } - __host__ __device__ constexpr operator T() const - { - if constexpr(N == 1) - { - return data_.dTxN[Number<0>{}]; - } - else - { - return data_.dTxN; // XXX this should cause an error - } - } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dxN; - } - else if constexpr(is_same_v) - { - return data_.dTxN; - } - else if constexpr(is_same_v) - { - return data_.dNx1; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dxN; - } - else if 
constexpr(is_same_v) - { - return data_.dTxN; - } - else if constexpr(is_same_v) - { - return data_.dNx1; - } - else - { - return err; - } - } -}; - -// implementation for f6x16 and f6x32 -template -struct non_native_vector_base< - T, - N, - ck::enable_if_t> -{ - using data_t = - typename nnvb_data_t_selector::type; // select data_t based on declared base type - using element_t = typename T::element_type; // select element_t based on declared element type - static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); - using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); - using type = non_native_vector_base; - - union alignas(next_pow2(N * sizeof(T))) - { - data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; - } data_; - - // Broadcast single value to vector - __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{} - { - // TODO: consider removing initialization similar to vector_type - - ck::static_for<0, N, 1>{}([&](auto i) { - data_.dxN(i) = a; // broadcast value to all elements - }); - } - - __host__ __device__ constexpr non_native_vector_base(T f) - : non_native_vector_base(bit_cast(f)) - { - } - - __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - - __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - - __host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {} - - __host__ __device__ constexpr operator data_v() const { return data_.dN; } - - __host__ __device__ constexpr operator T() const - { - if constexpr(N == 1) - { - return data_.dTxN[Number<0>{}]; - } - else - { - return err; // XXX this should cause an error - } - } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same_v || is_same_v || is_same_v, - 
"Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dNx1; - } - else if constexpr(is_same_v) - { - return data_.dxN; - } - else if constexpr(is_same_v) - { - return data_.dTxN; - } - else - { - return err; - } - } -}; - -template -struct scalar_type>> -{ - using type = typename non_native_vector_base::data_t; - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type>> -{ - using type = typename non_native_vector_base::element_t; - static constexpr index_t vector_size = N * non_native_vector_base::size_factor; -}; - -// non-native vector_type implementation -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using type = d1_nnv_t; - - union alignas(next_pow2(1 * sizeof(T))) - { - d1_t d1_; - StaticallyIndexedArray d1x1_; - d1_nnv_t d1_nnv_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - - using type = d2_t; - - union alignas(next_pow2(2 * sizeof(T))) - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ 
constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - - using type = d4_t; - - union alignas(next_pow2(4 * sizeof(T))) - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src 
and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - - using type = d8_t; - - union alignas(next_pow2(8 * sizeof(T))) - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } -}; - -template -struct 
vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - - using type = d16_t; - - union alignas(next_pow2(16 * sizeof(T))) - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d2_t = 
non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - - using type = d32_t; - - union alignas(next_pow2(32 * sizeof(T))) - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template 
-struct vector_type()>> -{ - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - using d64_t = non_native_vector_base; - - using type = d64_t; - - union alignas(next_pow2(64 * sizeof(T))) - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else 
if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; + return typename vector_type_maker::type{}; +} // fp32 using float2_t = typename vector_type::type; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 4e477eed26..00fab270e8 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -260,7 +260,10 @@ struct DynamicBuffer x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && - is_same>::type, int8_t>::value && + is_same_v>::type, int8_t> && + !is_same_v, + pk_i4_t> && // TODO: This needs to be fixed for pk_i4_t which + // cannot be handled below, but is stored as int8_t workaround_int8_ds_write_issue) { if(is_valid_element) diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index b2ebf4b371..f43a2c0f9f 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -222,16 +222,27 @@ template __host__ __device__ constexpr auto next_power_of_two() { // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail - constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1)); + constexpr index_t Y = X > 1 ? (1 << (32 - __builtin_clz(X - 1))) : X; return Y; } template -__host__ __device__ constexpr auto next_power_of_two(Number x) +__host__ __device__ constexpr auto next_power_of_two(Number) { - // TODO: X need to be 2 ~ 0x7fffffff. 
0, 1, or larger than 0x7fffffff will compile fail - constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1)); - return Number{}; + return Number>{}; +} + +__host__ __device__ constexpr int32_t integer_log2_floor(int32_t x) +{ + // TODO: x needs to be in the range 1 ~ 0x7fffffff; + // the result of __builtin_clz is undefined if x is 0. + return 31 - __builtin_clz(x); +} + +__host__ __device__ constexpr bool is_power_of_two_integer(int32_t x) +{ + // TODO: x needs to be in the range 1 ~ 0x7fffffff + return x == (1 << integer_log2_floor(x)); } } // namespace math diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 161c4d37c3..11f0053585 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1841,7 +1841,7 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 float float_array[32]; } in{x}; - using array_type = uint8_t __attribute__((ext_vector_type(32))); + using array_type = NativeVectorT; array_type uint8_array; // collect the 6-bit values into an array @@ -2178,7 +2178,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 float float_array[32]; } in{x}; - using array_type = uint8_t __attribute__((ext_vector_type(32))); + using array_type = NativeVectorT; array_type uint8_array; // collect the 6-bit values into an array