diff --git a/Common/CMakeLists.txt b/Common/CMakeLists.txt index f435e269575aa..0b92758e45f43 100644 --- a/Common/CMakeLists.txt +++ b/Common/CMakeLists.txt @@ -16,5 +16,6 @@ add_subdirectory(Types) add_subdirectory(Utils) add_subdirectory(SimConfig) add_subdirectory(DCAFitter) +add_subdirectory(ML) o2_data_file(COPY maps DESTINATION Common) diff --git a/Common/ML/CMakeLists.txt b/Common/ML/CMakeLists.txt new file mode 100644 index 0000000000000..74287e774efa1 --- /dev/null +++ b/Common/ML/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright 2019-2020 CERN and copyright holders of ALICE O2. +# See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +# All rights not expressly granted are reserved. +# +# This software is distributed under the terms of the GNU General Public +# License v3 (GPL Version 3), copied verbatim in the file "COPYING". +# +# In applying this license CERN does not waive the privileges and immunities +# granted to it by virtue of its status as an Intergovernmental Organization +# or submit itself to any jurisdiction. + +o2_add_library(ML + SOURCES src/ort_interface.cxx + TARGETVARNAME targetName + PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime) diff --git a/Common/ML/include/ML/3rdparty/GPUORTFloat16.h b/Common/ML/include/ML/3rdparty/GPUORTFloat16.h new file mode 100644 index 0000000000000..db65328409d3c --- /dev/null +++ b/Common/ML/include/ML/3rdparty/GPUORTFloat16.h @@ -0,0 +1,867 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This code was created from: +// - https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_float16.h +// - https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_cxx_api.h + +#include +#include +#include +#include + +namespace o2 +{ + +namespace OrtDataType +{ + +namespace detail +{ + +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error OrtDataType::detail::endian is not implemented in this environment. +#endif +}; + +static_assert( + endian::native == endian::little || endian::native == endian::big, + "Only little-endian or big-endian native byte orders are supported."); + +} // namespace detail + +/// +/// Shared implementation between public and internal classes. CRTP pattern. +/// +template +struct Float16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + constexpr static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept + { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept + { + return IsNaN() ? 
val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7C00U; + static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; + static constexpr uint16_t kNegativeInfinityBits = 0xFC00U; + static constexpr uint16_t kPositiveQNaNBits = 0x7E00U; + static constexpr uint16_t kNegativeQNaNBits = 0xFE00U; + static constexpr uint16_t kEpsilonBits = 0x4170U; + static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number + static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number + static constexpr uint16_t kOneBits = 0x3C00U; + static constexpr uint16_t kMinusOneBits = 0xBC00U; + + uint16_t val{0}; + + Float16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept + { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept + { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept + { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept + { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept + { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept + { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept + { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept + { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept + { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. 
+ /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept + { + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } + + bool operator==(const Float16Impl& rhs) const noexcept + { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is not equal to anything, including itself. + return false; + } + return val == rhs.val; + } + + bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } + + bool operator<(const Float16Impl& rhs) const noexcept + { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is unordered with respect to everything, including itself. + return false; + } + + const bool left_is_negative = IsNegative(); + if (left_is_negative != rhs.IsNegative()) { + // When the signs of left and right differ, we know that left is less than right if it is + // the negative value. The exception to this is if both values are zero, in which case IEEE + // says they should be equal, even if the signs differ. + return left_is_negative && !AreZero(*this, rhs); + } + return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); + } +}; + +// The following Float16_t conversions are based on the code from +// Eigen library. + +// The conversion routines are Copyright (c) Fabian Giesen, 2016. +// The original license follows: +// +// Copyright (c) Fabian Giesen, 2016 +// All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted. +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +namespace detail +{ +union float32_bits { + unsigned int u; + float f; +}; +}; // namespace detail + +template +inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept +{ + detail::float32_bits f{}; + f.f = v; + + constexpr detail::float32_bits f32infty = {255 << 23}; + constexpr detail::float32_bits f16max = {(127 + 16) << 23}; + constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr unsigned int sign_mask = 0x80000000u; + uint16_t val = static_cast(0x0u); + + unsigned int sign = f.u & sign_mask; + f.u ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. 
as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but + // without arithmetic overflow. + f.u += 0xc8000fffU; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +template +inline float Float16Impl::ToFloatImpl() const noexcept +{ + constexpr detail::float32_bits magic = {113 << 23}; + constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + detail::float32_bits o{}; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + unsigned int exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // re-normalize + } + + // Attempt to workaround the Internal Compiler Error on ARM64 + // for bitwise | operator, including std::bitset +#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) + if (IsNegative()) { + return -o.f; + } +#else + // original code: + o.u |= (val & 0x8000U) << 16U; // sign bit +#endif + return o.f; +} + +/// Shared implementation between public and internal classes. CRTP pattern. +template +struct BFloat16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept + { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept + { + return IsNaN() ? 
val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7F80U; + static constexpr uint16_t kPositiveInfinityBits = 0x7F80U; + static constexpr uint16_t kNegativeInfinityBits = 0xFF80U; + static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U; + static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U; + static constexpr uint16_t kSignaling_NaNBits = 0x7F80U; + static constexpr uint16_t kEpsilonBits = 0x0080U; + static constexpr uint16_t kMinValueBits = 0xFF7FU; + static constexpr uint16_t kMaxValueBits = 0x7F7FU; + static constexpr uint16_t kRoundToNearest = 0x7FFFU; + static constexpr uint16_t kOneBits = 0x3F80U; + static constexpr uint16_t kMinusOneBits = 0xBF80U; + + uint16_t val{0}; + + BFloat16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept + { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept + { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept + { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept + { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept + { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept + { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept + { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept + { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept + { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. 
+ /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept + { + // IEEE defines that positive and negative zero are equal, this gives us a quick equality check + // for two values by or'ing the private bits together and stripping the sign. They are both zero, + // and therefore equivalent, if the resulting value is still zero. + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } +}; + +template +inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept +{ + uint16_t result; + if (std::isnan(v)) { + result = kPositiveQNaNBits; + } else { + auto get_msb_half = [](float fl) { + uint16_t result; +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) +#else + if (detail::endian::native == detail::endian::little) +#endif + { + std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&result, &fl, sizeof(uint16_t)); + } + return result; + }; + + uint16_t upper_bits = get_msb_half(v); + union { + uint32_t U32; + float F32; + }; + F32 = v; + U32 += (upper_bits & 1) + kRoundToNearest; + result = get_msb_half(F32); + } + return result; +} + +template +inline float BFloat16Impl::ToFloatImpl() const noexcept +{ + if (IsNaN()) { + return std::numeric_limits::quiet_NaN(); + } + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) +#else + if (detail::endian::native == detail::endian::little) +#endif + { + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; +} + +/** \brief IEEE 754 half-precision floating point data type + * + * \details This struct is used for converting float to float16 and back + * so the user could feed inputs and fetch outputs using these type. + * + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. + * + * \code{.unparsed} + * // This example demonstrates converion from float to float16 + * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; + * std::vector fp16_values; + * fp16_values.reserve(std::size(values)); + * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values), + * [](float value) { return Ort::Float16_t(value); }); + * + * \endcode + */ +struct Float16_t : OrtDataType::Float16Impl { + private: + /// + /// Constructor from a 16-bit representation of a float16 value + /// No conversion is done here. + /// + /// 16-bit representation + constexpr explicit Float16_t(uint16_t v) noexcept { val = v; } + + public: + using Base = OrtDataType::Float16Impl; + + /// + /// Default constructor + /// + Float16_t() = default; + + /// + /// Explicit conversion to uint16_t representation of float16. + /// + /// uint16_t bit representation of float16 + /// new instance of Float16_t + constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } + + /// + /// __ctor from float. Float is converted into float16 16-bit representation. 
+ /// + /// float value + explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + /// + /// Checks if the value is negative + /// + /// true if negative + using Base::IsNegative; + + /// + /// Tests if the value is NaN + /// + /// true if NaN + using Base::IsNaN; + + /// + /// Tests if the value is finite + /// + /// true if finite + using Base::IsFinite; + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + using Base::IsPositiveInfinity; + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + using Base::IsNegativeInfinity; + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + using Base::IsInfinity; + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + using Base::IsNaNOrZero; + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + using Base::IsNormal; + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + using Base::IsSubnormal; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + using Base::Abs; + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + using Base::Negate; + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + using Base::AreZero; + + /// + /// User defined conversion operator. Converts Float16_t to float. + /// + explicit operator float() const noexcept { return ToFloat(); } + + using Base::operator==; + using Base::operator!=; + using Base::operator<; +}; + +static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); + +/** \brief bfloat16 (Brain Floating Point) data type + * + * \details This struct is used for converting float to bfloat16 and back + * so the user could feed inputs and fetch outputs using these type. + * + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. + * + * \code{.unparsed} + * // This example demonstrates converion from float to float16 + * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; + * std::vector bfp16_values; + * bfp16_values.reserve(std::size(values)); + * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values), + * [](float value) { return Ort::BFloat16_t(value); }); + * + * \endcode + */ +struct BFloat16_t : OrtDataType::BFloat16Impl { + private: + /// + /// Constructor from a uint16_t representation of bfloat16 + /// used in FromBits() to escape overload resolution issue with + /// constructor from float. + /// No conversion is done. + /// + /// 16-bit bfloat16 value + constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; } + + public: + using Base = OrtDataType::BFloat16Impl; + + BFloat16_t() = default; + + /// + /// Explicit conversion to uint16_t representation of bfloat16. 
+ /// + /// uint16_t bit representation of bfloat16 + /// new instance of BFloat16_t + static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } + + /// + /// __ctor from float. Float is converted into bfloat16 16-bit representation. + /// + /// float value + explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + /// + /// Checks if the value is negative + /// + /// true if negative + using Base::IsNegative; + + /// + /// Tests if the value is NaN + /// + /// true if NaN + using Base::IsNaN; + + /// + /// Tests if the value is finite + /// + /// true if finite + using Base::IsFinite; + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + using Base::IsPositiveInfinity; + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + using Base::IsNegativeInfinity; + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + using Base::IsInfinity; + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + using Base::IsNaNOrZero; + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + using Base::IsNormal; + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + using Base::IsSubnormal; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + using Base::Abs; + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + using Base::Negate; + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + using Base::AreZero; + + /// + /// User defined conversion operator. Converts BFloat16_t to float. + /// + explicit operator float() const noexcept { return ToFloat(); } + + // We do not have an inherited impl for the below operators + // as the internal class implements them a little differently + bool operator==(const BFloat16_t& rhs) const noexcept; + bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); } + bool operator<(const BFloat16_t& rhs) const noexcept; +}; + +static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); + +} // namespace OrtDataType + +} // namespace o2 \ No newline at end of file diff --git a/Common/ML/include/ML/ort_interface.h b/Common/ML/include/ML/ort_interface.h new file mode 100644 index 0000000000000..e2049b8508cb4 --- /dev/null +++ b/Common/ML/include/ML/ort_interface.h @@ -0,0 +1,92 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". 
+// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +/// \file ort_interface.h +/// \author Christian Sonnabend +/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU + +#ifndef O2_ML_ONNX_INTERFACE_H +#define O2_ML_ONNX_INTERFACE_H + +// C++ and system includes +#include +#include +#include +#include +#include + +// O2 includes +#include "Framework/Logger.h" + +namespace o2 +{ + +namespace ml +{ + +class OrtModel +{ + + public: + // Constructor + OrtModel() = default; + OrtModel(std::unordered_map optionsMap) { reset(optionsMap); } + void init(std::unordered_map optionsMap) { reset(optionsMap); } + void reset(std::unordered_map); + + virtual ~OrtModel() = default; + + // Conversion + template + std::vector v2v(std::vector&, bool = true); + + // Inferencing + template // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h + std::vector inference(std::vector&); + + template // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h + std::vector inference(std::vector>&); + + // template // class I is the input data type, e.g. float, class T the throughput data type and class O is the output data type + // std::vector inference(std::vector&); + + // Reset session + void resetSession(); + + std::vector> getNumInputNodes() const { return mInputShapes; } + std::vector> getNumOutputNodes() const { return mOutputShapes; } + std::vector getInputNames() const { return mInputNames; } + std::vector getOutputNames() const { return mOutputNames; } + + void setActiveThreads(int threads) { intraOpNumThreads = threads; } + + private: + // ORT variables -> need to be hidden as Pimpl + struct OrtVariables; + OrtVariables* pImplOrt; + + // Input & Output specifications of the loaded network + std::vector inputNamesChar, outputNamesChar; + std::vector mInputNames, mOutputNames; + std::vector> mInputShapes, mOutputShapes; + + // Environment settings + std::string modelPath, device = "cpu", dtype = "float"; // device options should be cpu, rocm, migraphx, cuda + int intraOpNumThreads = 0, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0; + + std::string printShape(const std::vector&); +}; + +} // namespace ml + +} // namespace o2 + +#endif // O2_ML_ORT_INTERFACE_H diff --git a/Common/ML/src/ort_interface.cxx b/Common/ML/src/ort_interface.cxx new file mode 100644 index 0000000000000..27ac8eee16b7b --- /dev/null +++ b/Common/ML/src/ort_interface.cxx @@ -0,0 +1,280 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. 
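Editor's note (not part of the diff): for orientation, a minimal usage sketch of the OrtModel interface declared in ort_interface.h above. The model path and feature count are hypothetical placeholders; the option keys are the ones parsed in OrtModel::reset() below, and inference<float, float> is one of the specializations provided in this file.

#include "ML/ort_interface.h"

#include <string>
#include <unordered_map>
#include <vector>

void runModelExample()
{
  // Option keys mirror those read in OrtModel::reset(); the path is a placeholder.
  std::unordered_map<std::string, std::string> options{
    {"model-path", "/path/to/model.onnx"},
    {"device", "CPU"},            // or "ROCM" / "MIGRAPHX" / "CUDA", depending on the ORT build
    {"dtype", "float"},
    {"intra-op-num-threads", "1"}};

  o2::ml::OrtModel model(options); // equivalent to default-constructing and calling init(options)

  // inference() takes a flattened, row-major batch: the vector length must be a multiple of the
  // second input dimension, since the batch size is derived as input.size() / mInputShapes[0][1].
  constexpr size_t nFeatures = 8;               // hypothetical width of the single input tensor
  std::vector<float> input(4 * nFeatures, 0.f); // a batch of four all-zero feature rows
  std::vector<float> output = model.inference<float, float>(input);
  (void)output;
}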
+ +/// \file ort_interface.cxx +/// \author Christian Sonnabend +/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU + +#include "ML/ort_interface.h" +#include "ML/3rdparty/GPUORTFloat16.h" + +// ONNX includes +#include + +namespace o2 +{ + +namespace ml +{ + +struct OrtModel::OrtVariables { // The actual implementation is hidden in the .cxx file + // ORT runtime objects + Ort::RunOptions runOptions; + std::shared_ptr env = nullptr; + std::shared_ptr session = nullptr; ///< ONNX session + Ort::SessionOptions sessionOptions; + Ort::AllocatorWithDefaultOptions allocator; + Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); +}; + +void OrtModel::reset(std::unordered_map optionsMap) +{ + + pImplOrt = new OrtVariables(); + + // Load from options map + if (!optionsMap.contains("model-path")) { + LOG(fatal) << "(ORT) Model path cannot be empty!"; + } + modelPath = optionsMap["model-path"]; + device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU"); + dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float"); + deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0); + allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0); + intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0); + loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0); + enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0); + enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0); + + std::string dev_mem_str = "Hip"; +#ifdef ORT_ROCM_BUILD + if (device == "ROCM") { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId)); + LOG(info) << "(ORT) ROCM execution provider set"; + } +#endif +#ifdef ORT_MIGRAPHX_BUILD + if (device == "MIGRAPHX") { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId)); + LOG(info) << "(ORT) MIGraphX execution provider set"; + } +#endif +#ifdef ORT_CUDA_BUILD + if (device == "CUDA") { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId)); + LOG(info) << "(ORT) CUDA execution provider set"; + dev_mem_str = "Cuda"; + } +#endif + + if (allocateDeviceMemory) { + pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault); + LOG(info) << "(ORT) Memory info set to on-device memory"; + } + + if (device == "CPU") { + (pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads); + if (intraOpNumThreads > 1) { + (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL); + } else if (intraOpNumThreads == 1) { + (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); + } + LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads"; + } + + (pImplOrt->sessionOptions).DisableMemPattern(); + (pImplOrt->sessionOptions).DisableCpuMemArena(); + + if (enableProfiling) { + if (optionsMap.contains("profiling-output-path")) { + (pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str()); + } else { + LOG(warning) << "(ORT) If 
profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now."; + (pImplOrt->sessionOptions).DisableProfiling(); + } + } else { + (pImplOrt->sessionOptions).DisableProfiling(); + } + (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations)); + (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel)); + + pImplOrt->env = std::make_shared(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str())); + pImplOrt->session = std::make_shared(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions); + + for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) { + mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get()); + } + for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) { + mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape()); + } + for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) { + mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get()); + } + for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) { + mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape()); + } + + inputNamesChar.resize(mInputNames.size(), nullptr); + std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar), + [&](const std::string& str) { return str.c_str(); }); + outputNamesChar.resize(mOutputNames.size(), nullptr); + std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar), + [&](const std::string& str) { return str.c_str(); }); + + // Print names + if (loggingLevel > 1) { + LOG(info) << "Input Nodes:"; + for (size_t i = 0; i < mInputNames.size(); i++) { + LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]); + } + + LOG(info) << "Output Nodes:"; + for (size_t i = 0; i < mOutputNames.size(); i++) { + LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]); + } + } +} + +void OrtModel::resetSession() +{ + pImplOrt->session = std::make_shared(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions); +} + +template +std::vector OrtModel::v2v(std::vector& input, bool clearInput) +{ + if constexpr (std::is_same_v) { + return input; + } else { + std::vector output(input.size()); + std::transform(std::begin(input), std::end(input), std::begin(output), [](I f) { return O(f); }); + if (clearInput) { + input.clear(); + } + return output; + } +} + +template // class I is the input data type, e.g. float, class O is the output data type, e.g. 
O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h +std::vector OrtModel::inference(std::vector& input) +{ + std::vector inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + std::vector inputTensor; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, reinterpret_cast(input.data()), input.size(), inputShape.data(), inputShape.size())); + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + O* outputValues = reinterpret_cast(outputTensors[0].template GetTensorMutableData()); + std::vector outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +template // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h +std::vector OrtModel::inference(std::vector>& input) +{ + std::vector inputTensor; + for (auto i : input) { + std::vector inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, reinterpret_cast(i.data()), i.size(), inputShape.data(), inputShape.size())); + } + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + O* outputValues = reinterpret_cast(outputTensors[0].template GetTensorMutableData()); + std::vector outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +std::string OrtModel::printShape(const std::vector& v) +{ + std::stringstream ss(""); + for (size_t i = 0; i < v.size() - 1; i++) { + ss << v[i] << "x"; + } + ss << v[v.size() - 1]; + return ss.str(); +} + +template <> +std::vector OrtModel::inference(std::vector& input) +{ + std::vector inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + std::vector inputTensor; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size())); + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + float* outputValues = outputTensors[0].template GetTensorMutableData(); + std::vector outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +template <> +std::vector OrtModel::inference(std::vector& input) +{ + std::vector inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + std::vector inputTensor; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, reinterpret_cast(input.data()), input.size(), inputShape.data(), inputShape.size())); + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + float* outputValues = outputTensors[0].template GetTensorMutableData(); + std::vector outputValuesVec{outputValues, outputValues + inputShape[0] * 
mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +template <> +std::vector OrtModel::inference(std::vector& input) +{ + std::vector inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + std::vector inputTensor; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, reinterpret_cast(input.data()), input.size(), inputShape.data(), inputShape.size())); + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + OrtDataType::Float16_t* outputValues = reinterpret_cast(outputTensors[0].template GetTensorMutableData()); + std::vector outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +template <> +std::vector OrtModel::inference(std::vector& input) +{ + std::vector inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + std::vector inputTensor; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, reinterpret_cast(input.data()), input.size(), inputShape.data(), inputShape.size())); + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + OrtDataType::Float16_t* outputValues = reinterpret_cast(outputTensors[0].template GetTensorMutableData()); + std::vector outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +template <> +std::vector OrtModel::inference(std::vector>& input) +{ + std::vector inputTensor; + for (auto i : input) { + std::vector inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]}; + inputTensor.emplace_back(Ort::Value::CreateTensor(pImplOrt->memoryInfo, reinterpret_cast(i.data()), i.size(), inputShape.data(), inputShape.size())); + } + // input.clear(); + auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size()); + OrtDataType::Float16_t* outputValues = reinterpret_cast(outputTensors[0].template GetTensorMutableData()); + std::vector outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]}; + outputTensors.clear(); + return outputValuesVec; +} + +} // namespace ml + +} // namespace o2
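Editor's note (not part of the diff): a minimal sketch exercising the Float16_t helpers defined in GPUORTFloat16.h, useful when preparing inputs or interpreting outputs of the half-precision inference() specializations above. The function name and values are illustrative only; only members visible in the header are used.

#include "ML/3rdparty/GPUORTFloat16.h"

#include <cassert>

void float16RoundTripExample()
{
  using o2::OrtDataType::Float16_t;

  Float16_t half(1.5f);        // float -> 16-bit half, via ToUint16Impl()
  float back = half.ToFloat(); // 16-bit half -> float, via ToFloatImpl()
  assert(back == 1.5f);        // 1.5 is exactly representable in fp16, so the round trip is exact

  Float16_t one = Float16_t::FromBits(Float16_t::kOneBits); // reinterprets raw bits, no conversion
  assert(static_cast<float>(one) == 1.0f);

  // IEEE 754 semantics: NaN compares unequal to everything, including itself.
  Float16_t nan = Float16_t::FromBits(Float16_t::kPositiveQNaNBits);
  assert(nan.IsNaN() && !(nan == nan));
}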