From 40e4d4697bb3f848a87c11d8ae38cbc702050afb Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 22 Jan 2026 20:30:26 +0000 Subject: [PATCH 1/7] Refactor vector_type to reduce build time --- include/ck/utility/data_type.hpp | 62 +- include/ck/utility/dtype_vector.hpp | 2404 +++-------------- include/ck/utility/math.hpp | 21 +- .../ck/utility/statically_indexed_array.hpp | 1 + 4 files changed, 445 insertions(+), 2043 deletions(-) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8e6f875c399..f37437336ad 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -34,9 +34,17 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format -// scalar_type -template -struct scalar_type; +/** + * @brief Mapping of incoming type to local native storage type and vector size + * @tparam T Incoming data type + */ +template +struct scalar_type +{ + // Basic data type mapping to unsigned _BitInt of appropriate size + using type = unsigned _BitInt(8 * sizeof(T)); + static constexpr index_t vector_size = 1; +}; struct f4x2_pk_t { @@ -191,12 +199,6 @@ struct pk_i4_t __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} }; -inline constexpr auto next_pow2(uint32_t x) -{ - // Precondition: x > 1. - return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; -} - // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, // native types: bool template @@ -208,10 +210,6 @@ inline constexpr bool is_native_type() is_same_v || is_same_v || is_same::value; } -// scalar_type -template -struct scalar_type; - // is_scalar_type template struct is_scalar_type @@ -224,14 +222,13 @@ template using has_same_scalar_type = is_same>::type, typename scalar_type>::type>; -template -struct scalar_type +template <> +struct scalar_type { - using type = T; - static constexpr index_t vector_size = N; + using type = bool; + static constexpr index_t vector_size = 1; }; -// template <> struct scalar_type { @@ -293,35 +290,35 @@ struct scalar_type template <> struct scalar_type { - using type = pk_i4_t; + using type = typename pk_i4_t::type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f8_fnuz_t::data_type; + using type = typename f8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t::data_type; + using type = typename bf8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f8_ocp_t::data_type; + using type = typename f8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_ocp_t::data_type; + using type = typename bf8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; @@ -329,7 +326,7 @@ struct scalar_type template <> struct scalar_type { - using type = e8m0_bexp_t::type; + using type = typename e8m0_bexp_t::type; static constexpr index_t vector_size = 1; }; #endif @@ -337,42 +334,35 @@ struct scalar_type template <> struct scalar_type { - using type = f4x2_pk_t::type; + using type = typename f4x2_pk_t::type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f6x32_pk_t::storage_type; + using type = typename f6x32_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf6x32_pk_t::storage_type; + using type = typename bf6x32_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f6x16_pk_t::storage_type; + using type = typename f6x16_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf6x16_pk_t::storage_type; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = bool; + using type = typename bf6x16_pk_t::storage_type; static constexpr index_t vector_size = 1; }; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ebdbbb107d7..85f50b11d30 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -3,197 +3,103 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" namespace ck { -// vector_type -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type, N>; - -// vector_type_maker -// This is the right way to handle "vector of vectors": making a bigger vector instead -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct scalar_type> +/** + * @brief Wrapper for native vector type + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +using NativeVectorT = T __attribute__((ext_vector_type(Rank))); + +/** + * @brief scalar_type trait override for NativeVectorT + * @tparam T The vector type + * @tparam Rank The number of elements in the vector + */ +template +struct scalar_type> { using type = T; - static constexpr index_t vector_size = N; -}; - -template -struct vector_type_maker -{ - using type = vector_type; + static constexpr index_t vector_size = Rank; }; -template -struct vector_type_maker, N0> -{ - using type = vector_type; -}; +__device__ int static err = 0; -template -using vector_type_maker_t = typename vector_type_maker::type; +template +struct non_native_vector_base; template -__host__ __device__ constexpr auto make_vector_type(Number) -{ - return typename vector_type_maker::type{}; -} - -template -struct vector_type()>> +struct non_native_vector_base< + T, + N, + ck::enable_if_t> { - using d1_t = T; - using type = d1_t; + using data_t = typename scalar_type::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = data_t __attribute__((ext_vector_type(N))); + using type = non_native_vector_base; - union + union alignas(math::next_power_of_two()) { - T d1_; - StaticallyIndexedArray d1x1_; + data_v dN; // storage vector; + StaticallyIndexedArray_v2 dxN; + StaticallyIndexedArray_v2 dTxN; + StaticallyIndexedArray_v2 dNx1; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } - - template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; } -}; - -__device__ int static err = 0; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) + if constexpr(N == 1) { - return data_.d2x1_; + return data_.dxN[Number<0>{}]; } else { - return err; + return data_.dxN; // XXX this should cause an error } } - - template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr operator T() const { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) + if constexpr(N == 1) { - return data_.d2x1_; + return data_.dTxN[Number<0>{}]; } else { - return err; + return data_.dTxN; // XXX this should cause an error } } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t __attribute__((ext_vector_type(3))); - - using type = d3_t; - - union - { - d3_t d3_; - StaticallyIndexedArray d1x3_; - StaticallyIndexedArray d2x1_; - StaticallyIndexedArray d3x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x3_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d2x1_; + return data_.dTxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d3x1_; + return data_.dNx1; } else { @@ -204,20 +110,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x3_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d2x1_; + return data_.dTxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d3x1_; + return data_.dNx1; } else { @@ -226,138 +132,80 @@ struct vector_type()>> } }; -template -struct vector_type()>> +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - - using type = d4_t; + using data_t = typename scalar_type::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; - union + union alignas(math::next_power_of_two()) { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; + data_v dN; // storage vector; + StaticallyIndexedArray_v2 dxN; + StaticallyIndexedArray_v2 dTxN; + StaticallyIndexedArray_v2 dNx1; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const + // Broadcast single value to vector + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{} { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + // TODO: consider removing initialization similar to vector_type - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } + ck::static_for<0, N, 1>{}([&](auto i) { + data_.dxN(i) = a; // broadcast value to all elements + }); } - template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } } -}; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d5_t __attribute__((ext_vector_type(5))); - - using type = d5_t; + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - union - { - d5_t d5_; - StaticallyIndexedArray d1x5_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d5x1_; - } data_; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {} - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __host__ __device__ constexpr operator data_v() const { return data_.dN; } - template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr operator T() const { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) + if constexpr(N == 1) { - return data_.d5x1_; + return data_.dTxN[Number<0>{}]; } else { - return err; + return err; // XXX this should cause an error } } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same_v) { - return data_.d1x5_; + return data_.dNx1; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d4x1_; + return data_.dxN; } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return data_.d5x1_; + return data_.dTxN; } else { @@ -366,1781 +214,333 @@ struct vector_type()>> } }; -template -struct vector_type()>> +template +struct scalar_type>> { - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t __attribute__((ext_vector_type(3))); - typedef T d6_t __attribute__((ext_vector_type(6))); + using type = typename non_native_vector_base::data_t; + static constexpr index_t vector_size = N; +}; - using type = d6_t; +template +struct scalar_type>> +{ + using type = typename non_native_vector_base::element_t; + static constexpr index_t vector_size = N * non_native_vector_base::size_factor; +}; - union +/** + * @brief Helper struct to determine the storage type for vector_type + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * @tparam Enable SFINAE helper + */ +template +struct vector_type_storage; + +/** + * @brief Vector storage type for native scalar types. + * @tparam T The element type of the vector + * @note For Rank = 1 and native types, the storage type is simply T itself (scalar) + */ +template +struct vector_type_storage()>> +{ + using type = T; +}; + +/** + * @brief Vector storage type for native vector types. + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * + * Assigns a native vector type based on the element type and rank. + * For boolean types, uses a C-style array `T[Rank]`, otherwise uses + * the `NativeVectorT` template specialization. + * + * @note Special handling note: + * Sub-byte sizes such as bool have different sizes in ext_vector_type (via NativeVectorT) vs array + * types due to packing. Builtin vector types pack bool elements, while C++ arrays use 1 byte per + * bool as a standard (minimum write size = 1 byte). e.g., ext_vector_type(bool, 4) is packed as + * minimum 1 byte, while bool[4] is 4 bytes. vector_type::AsType, aliases with + * StaticallyIndexedArray_v2 which is C-style array under the hood, so we need to avoid using + * ext_vector_type with bool due to potential for data slicing errors. + */ +template +struct vector_type_storage() && (Rank > 1)>> +{ + using type = std::conditional_t, T[Rank], NativeVectorT>; + ; +}; + +/** + * @brief Vector storage type for non-native vector types. + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * @note For non-native types, the storage type is non_native_vector_base + */ +template +struct vector_type_storage()>> +{ + using type = non_native_vector_base; +}; + +/** + * @brief Convenience wrapper for vector_type_storage + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +using vector_type_storage_t = typename vector_type_storage::type; + +/** + * @brief Trait to check whether one storage class is the same as another (e.g., same scalar, or + * same vector class). + * @tparam Lhs The source storage type + * @tparam Rhs The comparator storage type + * + * Same storage classes are: + * - Same type + * - Same template vector types with matching base type (may have different ranks) + * - C-style arrays of same base type (may have different ranks) + */ +template +struct is_same_storage_class : public false_type +{ +}; + +/** + * @brief Same type storage class + * @tparam T The storage type + */ +template +struct is_same_storage_class : public true_type +{ +}; + +/** + * @brief Template vector types of same base type with different ranks + * @tparam VecT The vector template class type (e.g., vector_type, NativeVectorT, + * non_native_vector_base) + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + */ +template