diff --git a/CMakeLists.txt b/CMakeLists.txt index 49831bfebf9..bf603615068 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,26 @@ include(CheckOpenSSLIsBoringSSL) include(CheckOpenSSLIsQuictls) include(CheckOpenSSLIsAwsLc) find_package(OpenSSL REQUIRED) +auto_option(SIMDUTF FEATURE_VAR TS_USE_SIMDUTF PACKAGE_DEPENDS simdutf) +# ats_base64_decode uses base64_default_or_url (added in 7.0.0) and the +# decode_up_to_bad_char parameter on base64_to_binary_safe (also 7.0.0). A +# simdutf package older than that would pass find_package and then fail at +# compile time. Disable the feature with a status message in AUTO, hard-error +# if the user asked for it explicitly. +set(_MIN_SIMDUTF_VERSION "7.0.0") +if(TS_USE_SIMDUTF AND simdutf_VERSION VERSION_LESS "${_MIN_SIMDUTF_VERSION}") + if(ENABLE_SIMDUTF STREQUAL "AUTO") + message( + STATUS + "Found simdutf ${simdutf_VERSION} but >= ${_MIN_SIMDUTF_VERSION} required for ats_base64; falling back to scalar" + ) + set(TS_USE_SIMDUTF FALSE) + else() + message(FATAL_ERROR "TS_USE_SIMDUTF requires simdutf >= ${_MIN_SIMDUTF_VERSION} (found ${simdutf_VERSION}). " + "Either upgrade simdutf or configure with -DENABLE_SIMDUTF=OFF." + ) + endif() +endif() check_openssl_is_boringssl(SSLLIB_IS_BORINGSSL BORINGSSL_VERSION "${OPENSSL_INCLUDE_DIR}") check_openssl_is_awslc(SSLLIB_IS_AWSLC AWSLC_VERSION "${OPENSSL_INCLUDE_DIR}") diff --git a/include/tscore/ink_config.h.cmake.in b/include/tscore/ink_config.h.cmake.in index 73c8b860fb9..988ed5cb8d0 100644 --- a/include/tscore/ink_config.h.cmake.in +++ b/include/tscore/ink_config.h.cmake.in @@ -163,6 +163,7 @@ const int DEFAULT_STACKSIZE = @DEFAULT_STACK_SIZE@; #cmakedefine01 TS_USE_POSIX_CAP #cmakedefine01 TS_USE_QUIC #cmakedefine01 TS_USE_REMOTE_UNWINDING +#cmakedefine01 TS_USE_SIMDUTF #cmakedefine01 TS_USE_TLS13 #cmakedefine01 TS_USE_TLS_ASYNC #cmakedefine01 TS_USE_TPROXY diff --git a/src/tscore/CMakeLists.txt b/src/tscore/CMakeLists.txt index 7790adc87dd..effccd0647c 100644 --- a/src/tscore/CMakeLists.txt +++ b/src/tscore/CMakeLists.txt @@ -110,6 +110,10 @@ target_link_libraries( tscore PUBLIC OpenSSL::Crypto libswoc::libswoc yaml-cpp::yaml-cpp systemtap::systemtap resolv::resolv ts::tsutil ) +if(TS_USE_SIMDUTF) + target_link_libraries(tscore PUBLIC simdutf::simdutf) +endif() + if(TS_USE_POSIX_CAP) target_link_libraries(tscore PUBLIC cap::cap) endif() @@ -158,6 +162,7 @@ if(BUILD_TESTING) unit_tests/test_Throttler.cc unit_tests/test_Tokenizer.cc unit_tests/test_arena.cc + unit_tests/test_ink_base64.cc unit_tests/test_ink_inet.cc unit_tests/test_ink_memory.cc unit_tests/test_ink_string.cc diff --git a/src/tscore/ink_base64.cc b/src/tscore/ink_base64.cc index 849d7c8ce83..1a1bab12afb 100644 --- a/src/tscore/ink_base64.cc +++ b/src/tscore/ink_base64.cc @@ -1,6 +1,47 @@ /** @file - A brief file description + Base64 encoding and decoding. + + The public entry points (`ats_base64_encode` / `ats_base64_decode`, also + exposed through `TSBase64Encode` / `TSBase64Decode`) dispatch between two + internal implementations: + + - A hand-rolled scalar path, always present, used directly when + TS_USE_SIMDUTF is disabled, and used for inputs below the SIMD + crossover threshold when TS_USE_SIMDUTF is enabled. The scalar path + avoids simdutf's runtime ISA dispatch and virtual-call overhead, + which would otherwise dominate the cost for tiny inputs (e.g. the + 8-byte SnowflakeID encode). + + - simdutf, used for larger inputs when TS_USE_SIMDUTF is enabled. + simdutf provides SIMD-accelerated kernels and is several times + faster than the scalar path once the input is big enough to amortize + its per-call overhead. + + Thresholds were chosen empirically on a 2.1 GHz Broadwell-EP Xeon + (AVX2) using tools/benchmark/benchmark_ink_base64. The exact crossover + shifts on different cores but lies within an order of magnitude of these + values everywhere we've measured. + + Both paths preserve the same public contract: + + - encode: standard RFC 1521 alphabet (`+`, `/`), `=` padding, no line + breaks, trailing NUL written at outBuffer[length]. + + - decode: accepts both standard (`+`, `/`) and URL-safe (`-`, `_`) + alphabets in the same input; tolerates missing padding; on any + non-alphabet byte (including ASCII whitespace, '=', or garbage), + truncates and returns success with whatever was decoded up to that + point; trailing NUL written at outBuffer[length]; supports in-place + decode (dst == src). + + Decode whitespace alignment: simdutf's forgiving-base64 mode would + silently skip ASCII whitespace and continue. To keep TSBase64Decode + results independent of build configuration and input size, the wrapper + pre-scans the input with the same printableToSixBit table the scalar + path uses and truncates inBufferSize at the first non-alphabet byte + before handing it to either implementation. Both paths therefore see + the same prefix of valid alphabet bytes and produce identical output. @section license License @@ -20,32 +61,63 @@ See the License for the specific language governing permissions and limitations under the License. */ - -/* - * Base64 encoding and decoding as according to RFC1521. Similar to uudecode. - * - * RFC 1521 requires inserting line breaks for long lines. The basic web - * authentication scheme does not require them. This implementation is - * intended for web-related use, and line breaks are not implemented. - * - */ #include "tscore/ink_platform.h" #include "tscore/ink_base64.h" #include "tscore/ink_assert.h" -// TODO: The code here seems a bit klunky, and could probably be improved a bit. +#if TS_USE_SIMDUTF +#include -bool -ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +// Inputs at or below these byte counts stay on the scalar path, where they +// outrun simdutf's per-call overhead. inBufferSize for encode is the binary +// plaintext length; for decode it is the base64-encoded length. +constexpr size_t BASE64_ENCODE_SIMD_THRESHOLD = 24; +constexpr size_t BASE64_DECODE_SIMD_THRESHOLD = 48; +#endif + +namespace +{ + +/* Converts a printable character to its six bit representation. */ +const unsigned char printableToSixBit[256] = { + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, + 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, + 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; + +constexpr unsigned char MAX_PRINT_VAL = 63; + +inline unsigned char +decode_byte(char c) +{ + return printableToSixBit[static_cast(c)]; +} + +// Count the leading base64-alphabet bytes (standard or URL-safe). The result +// is the prefix length that both decode paths actually consume; any byte at +// or after this index is whitespace, '=', or garbage and is dropped. +inline size_t +count_alphabet_prefix(const char *inBuffer, size_t inBufferSize) +{ + size_t valid = 0; + while (valid < inBufferSize && decode_byte(inBuffer[valid]) <= MAX_PRINT_VAL) { + ++valid; + } + return valid; +} + +// Hand-rolled scalar encode. Caller has already validated outBufSize. +void +encode_scalar(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t *length) { static const char _codes[66] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; char *obuf = outBuffer; char in_tail[4]; - if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { - return false; - } - while (inBufferSize > 2) { *obuf++ = _codes[(inBuffer[0] >> 2) & 077]; *obuf++ = _codes[((inBuffer[0] & 03) << 4) | ((inBuffer[1] >> 4) & 017)]; @@ -56,14 +128,6 @@ ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outB inBuffer += 3; } - /* - * We've done all the input groups of three chars. We're left - * with 0, 1, or 2 input chars. We have to add zero-bits to the - * right if we don't have enough input chars. - * If 0 chars left, we're done. - * If 1 char left, form 2 output chars, and add 2 pad chars to output. - * If 2 chars left, form 3 output chars, add 1 pad char to output. - */ if (inBufferSize == 0) { *obuf = '\0'; if (length) { @@ -88,81 +152,120 @@ ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outB *length = (obuf + 4) - outBuffer; } } +} - return true; +// Hand-rolled scalar decode. Caller has pre-scanned: every byte in +// inBuffer[0..inBufferSize) is in the base64 alphabet (decode_byte() returns +// <= 63). The caller has also validated outBufSize. +// +// This restructures the legacy decode tail handling. The previous code ran +// one extra loop iteration past the alphabet prefix when inBufferSize was in +// {1, 2, 3} (reading inBuffer[2..3] which was either OOB to the caller's +// buffer or past the valid prefix) and then read inBuffer[-2] in the trailing +// adjustment block when no loop iterations had advanced inBuffer. Process +// only complete 4-character groups in the main loop and decode any 2- or +// 3-byte tail explicitly; a 1-byte tail encodes nothing meaningful and is +// dropped, matching what an RFC 4648 decoder is supposed to do. +void +decode_scalar(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t *length) +{ + size_t decodedBytes = 0; + unsigned char *buf = outBuffer; + + while (inBufferSize >= 4) { + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + buf[2] = static_cast(decode_byte(inBuffer[2]) << 6 | decode_byte(inBuffer[3])); + buf += 3; + inBuffer += 4; + decodedBytes += 3; + inBufferSize -= 4; + } + + if (inBufferSize >= 2) { + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + decodedBytes += 1; + if (inBufferSize >= 3) { + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + decodedBytes += 1; + } + } + + outBuffer[decodedBytes] = '\0'; + if (length) { + *length = decodedBytes; + } } +} // namespace + bool -ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) { - return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); -} + if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { + return false; + } -/*------------------------------------------------------------------------- - This is a reentrant, and malloc free implementation of ats_base64_decode. - -------------------------------------------------------------------------*/ -#ifdef DECODE -#undef DECODE +#if TS_USE_SIMDUTF + if (inBufferSize > BASE64_ENCODE_SIMD_THRESHOLD) { + size_t written = simdutf::binary_to_base64(reinterpret_cast(inBuffer), inBufferSize, outBuffer); + outBuffer[written] = '\0'; + if (length) { + *length = written; + } + return true; + } #endif -#define DECODE(x) printableToSixBit[(unsigned char)x] -#define MAX_PRINT_VAL 63 + encode_scalar(inBuffer, inBufferSize, outBuffer, length); + return true; +} -/* Converts a printable character to it's six bit representation */ -const unsigned char printableToSixBit[256] = { - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, - 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, - 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; +bool +ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +{ + return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); +} bool ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t outBufSize, size_t *length) { - size_t inBytes = 0; - size_t decodedBytes = 0; - unsigned char *buf = outBuffer; - int inputBytesDecoded = 0; - - // Make sure there is sufficient space in the output buffer if (outBufSize < ats_base64_decode_dstlen(inBufferSize)) { return false; } - // Ignore any trailing ='s or other undecodable characters. - // TODO: Perhaps that ought to be an error instead? - while (inBytes < inBufferSize && printableToSixBit[static_cast(inBuffer[inBytes])] <= MAX_PRINT_VAL) { - ++inBytes; - } - - for (size_t i = 0; i < inBytes; i += 4) { - buf[0] = static_cast(DECODE(inBuffer[0]) << 2 | DECODE(inBuffer[1]) >> 4); - buf[1] = static_cast(DECODE(inBuffer[1]) << 4 | DECODE(inBuffer[2]) >> 2); - buf[2] = static_cast(DECODE(inBuffer[2]) << 6 | DECODE(inBuffer[3])); + // Truncate to the leading base64-alphabet prefix. Doing this upfront for + // both paths is what keeps the SIMD and scalar decoders aligned on inputs + // that contain ASCII whitespace, '=' padding, or any other non-alphabet + // byte; otherwise simdutf's forgiving mode would skip whitespace and + // continue while the scalar would have stopped at it. + const size_t valid = count_alphabet_prefix(inBuffer, inBufferSize); - buf += 3; - inBuffer += 4; - decodedBytes += 3; - inputBytesDecoded += 4; - } +#if TS_USE_SIMDUTF + if (valid > BASE64_DECODE_SIMD_THRESHOLD) { + // Reserve one byte for the trailing NUL we always emit. The input we + // pass to simdutf is pure alphabet bytes (no whitespace, no '='), so + // last_chunk_options::loose handles the unpadded tail and + // decode_up_to_bad_char never triggers in practice. + size_t out_len = outBufSize - 1; + auto r = simdutf::base64_to_binary_safe(inBuffer, valid, reinterpret_cast(outBuffer), out_len, + simdutf::base64_default_or_url, simdutf::last_chunk_handling_options::loose, + /*decode_up_to_bad_char=*/true); - // Check to see if we decoded a multiple of 4 four - // bytes - if ((inBytes - inputBytesDecoded) & 0x3) { - if (DECODE(inBuffer[-2]) > MAX_PRINT_VAL) { - decodedBytes -= 2; - } else { - decodedBytes -= 1; + // OUTPUT_BUFFER_TOO_SMALL is impossible given the upfront dstlen check; + // be defensive anyway. + if (r.error == simdutf::error_code::OUTPUT_BUFFER_TOO_SMALL) { + return false; } - } - outBuffer[decodedBytes] = '\0'; - if (length) { - *length = decodedBytes; + outBuffer[out_len] = '\0'; + if (length) { + *length = out_len; + } + return true; } +#endif + decode_scalar(inBuffer, valid, outBuffer, length); return true; } diff --git a/src/tscore/unit_tests/test_ink_base64.cc b/src/tscore/unit_tests/test_ink_base64.cc new file mode 100644 index 00000000000..478edc9b928 --- /dev/null +++ b/src/tscore/unit_tests/test_ink_base64.cc @@ -0,0 +1,242 @@ +/** @file + + Unit tests for ats_base64_encode / ats_base64_decode. + + Runs as part of the standard test_tscore binary so the scalar and + simdutf decode paths are exercised by ctest in every build, not just + when ENABLE_BENCHMARKS is set. The scenarios bracket the SIMD + crossover thresholds (24 bytes for encode, 48 bytes for decode) so + that any future divergence between the two implementations is caught. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#include + +#include "tscore/ink_base64.h" + +#include +#include +#include +#include +#include +#include + +namespace +{ + +std::vector +make_random_bytes(std::size_t n, std::uint64_t seed = 0xC0FFEEULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (std::size_t i = 0; i < n; ++i) { + v[i] = static_cast(rng() & 0xFFU); + } + return v; +} + +std::string +encode_with_ats(const std::vector &in) +{ + std::string out; + out.resize(ats_base64_encode_dstlen(in.size())); + std::size_t n = 0; + bool ok = ats_base64_encode(in.data(), in.size(), out.data(), out.size(), &n); + REQUIRE(ok); + out.resize(n); + return out; +} + +} // namespace + +// Random round-trip across sizes that bracket both the encode threshold (24) +// and the decode threshold (48), so the scalar and simdutf paths are both +// exercised every run. +TEST_CASE("ats_base64 round-trip across SIMD thresholds", "[ats_base64]") +{ + for (std::size_t sz : std::array{0, 1, 8, 23, 24, 25, 47, 48, 49, 4096}) { + auto input = make_random_bytes(sz, 0xC0FFEE + sz); + auto encoded = encode_with_ats(input); + std::vector decoded(ats_base64_decode_dstlen(encoded.size()) + 1); + std::size_t dec_len = 0; + + CAPTURE(sz); + REQUIRE(ats_base64_decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &dec_len)); + REQUIRE(dec_len == sz); + REQUIRE(std::memcmp(decoded.data(), input.data(), sz) == 0); + } +} + +// Byte-exact fixture taken from InkAPITest's SDK_API_ENCODING regression +// test. Any future implementation swap must keep this passing. +TEST_CASE("ats_base64 InkAPITest fixture", "[ats_base64][fixture]") +{ + const char *url = "http://www.example.com/foo?fie= \"#%<>[]\\^`{}~&bar={test}&fum=Apache Traffic Server"; + const char *url_b64 = + "aHR0cDovL3d3dy5leGFtcGxlLmNvbS9mb28/ZmllPSAiIyU8PltdXF5ge31+JmJhcj17dGVzdH0mZnVtPUFwYWNoZSBUcmFmZmljIFNlcnZlcg=="; + const auto url_len = std::strlen(url); + const auto url_b64_len = std::strlen(url_b64); + + SECTION("encode produces byte-identical RFC1521 output with '=' padding") + { + std::array buf{}; + std::size_t enc_len = 0; + REQUIRE(ats_base64_encode(url, url_len, buf.data(), buf.size(), &enc_len)); + REQUIRE(enc_len == url_b64_len); + REQUIRE(std::strcmp(buf.data(), url_b64) == 0); + } + + SECTION("decode reproduces the original byte-for-byte") + { + std::array buf{}; + std::size_t dec_len = 0; + REQUIRE(ats_base64_decode(url_b64, url_b64_len, reinterpret_cast(buf.data()), buf.size(), &dec_len)); + REQUIRE(dec_len == url_len); + REQUIRE(std::strcmp(buf.data(), url) == 0); + } +} + +// The decoder accepts the URL-safe alphabet ('-' for '+', '_' for '/') in +// the same call as standard input. Long enough to exercise the simdutf path. +TEST_CASE("ats_base64_decode accepts URL-safe alphabet", "[ats_base64]") +{ + // Decode the same payload twice, once with standard '+/' and once with + // URL-safe '-_', and require identical output. The 0xfb/0xbf/0xff pattern + // produces '+' and '/' in the encoded form, so the URL-safe substitution + // actually changes bytes. + const std::vector payload = { + 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, + 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, + 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, + }; + + std::string standard = encode_with_ats(payload); + REQUIRE(standard.size() > 48); // encoded form must cross BASE64_DECODE_SIMD_THRESHOLD + std::string url_safe = standard; + for (auto &c : url_safe) { + if (c == '+') { + c = '-'; + } else if (c == '/') { + c = '_'; + } + } + REQUIRE(url_safe != standard); // payload chosen so the swap actually changes bytes + + std::vector out_std(ats_base64_decode_dstlen(standard.size()) + 1); + std::vector out_url(ats_base64_decode_dstlen(url_safe.size()) + 1); + std::size_t len_std = 0; + std::size_t len_url = 0; + REQUIRE(ats_base64_decode(standard.data(), standard.size(), out_std.data(), out_std.size(), &len_std)); + REQUIRE(ats_base64_decode(url_safe.data(), url_safe.size(), out_url.data(), out_url.size(), &len_url)); + + REQUIRE(len_std == payload.size()); + REQUIRE(len_url == payload.size()); + REQUIRE(std::memcmp(out_std.data(), payload.data(), payload.size()) == 0); + REQUIRE(std::memcmp(out_url.data(), payload.data(), payload.size()) == 0); +} + +// In-place decode (dst == src) must produce the same result as decoding into +// a separate buffer. Used by plugins/experimental/magick. +TEST_CASE("ats_base64_decode supports in-place (dst == src)", "[ats_base64]") +{ + for (std::size_t sz : std::array{1, 16, 24, 47, 48, 200}) { + auto input = make_random_bytes(sz, 0xBADF00D + sz); + std::string encoded = encode_with_ats(input); + const std::size_t enc_size = encoded.size(); + // The in-place buffer must hold both the encoded input AND the trailing + // NUL the decoder writes; encoded.size() is always >= the decoded size, + // so one extra byte for the NUL is enough. + std::string in_place = encoded; + in_place.resize(enc_size + 1); + + std::vector reference(ats_base64_decode_dstlen(encoded.size()) + 1); + std::size_t ref_len = 0; + std::size_t ip_len = 0; + + REQUIRE(ats_base64_decode(encoded.data(), enc_size, reference.data(), reference.size(), &ref_len)); + REQUIRE( + ats_base64_decode(in_place.data(), enc_size, reinterpret_cast(in_place.data()), in_place.size(), &ip_len)); + + CAPTURE(sz); + REQUIRE(ip_len == ref_len); + REQUIRE(std::memcmp(in_place.data(), reference.data(), ref_len) == 0); + } +} + +// A non-alphabet byte mid-input truncates: the decoder must stop at the +// first such byte and return the bytes decoded up to that point. This was +// the documented behavior of the legacy scalar path and the simdutf wrapper +// preserves it by pre-scanning the input. +TEST_CASE("ats_base64_decode truncates at first non-alphabet byte", "[ats_base64]") +{ + // 28 alphabet bytes then a stop byte then more alphabet (we expect the + // tail to be ignored). 28 chars decode to 21 bytes. Length is below the + // simdutf threshold so this exercises the scalar path. + const char *input = "AAAAAAAAAAAAAAAAAAAAAAAAAAAA!!!IGNORED-TAIL-IGNORED-TAIL"; + + std::array out{}; + std::size_t len = 0; + REQUIRE(ats_base64_decode(input, std::strlen(input), out.data(), out.size(), &len)); + REQUIRE(len == 21); + for (std::size_t i = 0; i < len; ++i) { + REQUIRE(out[i] == 0); // 'A' = base64 index 0, so 28 'A's decode to 21 zero bytes + } +} + +// Whitespace in the input should be treated like any other non-alphabet +// byte: the decoder stops at it. This is the property that keeps the SIMD +// and scalar paths aligned regardless of input length, since simdutf would +// otherwise silently skip whitespace and continue. +TEST_CASE("ats_base64_decode stops at ASCII whitespace", "[ats_base64]") +{ + // Construct an input long enough to cross the simdutf threshold so we + // exercise the wrapper's pre-scan, with a tab byte planted mid-buffer. + std::string input; + input.assign(60, 'A'); // 60 'A's -> 45 zero bytes if fully decoded + input[40] = '\t'; // first whitespace at index 40 -> 30 bytes after truncation + + std::array out{}; + std::size_t len = 0; + REQUIRE(ats_base64_decode(input.data(), input.size(), out.data(), out.size(), &len)); + REQUIRE(len == 30); +} + +// 1, 2, and 3 base64 chars decode to 0, 1, and 2 bytes respectively. Previous +// code had OOB reads in this path; this case guards against a regression. +TEST_CASE("ats_base64_decode handles very short alphabet inputs", "[ats_base64]") +{ + std::array out{}; + std::size_t len = 0; + + // 1 alphabet char: encodes nothing meaningful, decoded length is 0. + REQUIRE(ats_base64_decode("A", 1, out.data(), out.size(), &len)); + REQUIRE(len == 0); + + // 2 alphabet chars: decoded length is 1. + REQUIRE(ats_base64_decode("AA", 2, out.data(), out.size(), &len)); + REQUIRE(len == 1); + REQUIRE(out[0] == 0); + + // 3 alphabet chars: decoded length is 2. + REQUIRE(ats_base64_decode("AAA", 3, out.data(), out.size(), &len)); + REQUIRE(len == 2); + REQUIRE(out[0] == 0); + REQUIRE(out[1] == 0); +} diff --git a/tools/benchmark/CMakeLists.txt b/tools/benchmark/CMakeLists.txt index 49f25fad1c1..d6fa658ca9e 100644 --- a/tools/benchmark/CMakeLists.txt +++ b/tools/benchmark/CMakeLists.txt @@ -36,6 +36,9 @@ target_link_libraries(benchmark_SharedMutex PRIVATE Catch2::Catch2 ts::tscore li add_executable(benchmark_Random benchmark_Random.cc) target_link_libraries(benchmark_Random PRIVATE Catch2::Catch2WithMain ts::tscore) +add_executable(benchmark_ink_base64 benchmark_ink_base64.cc) +target_link_libraries(benchmark_ink_base64 PRIVATE Catch2::Catch2WithMain ts::tscore) + add_executable(benchmark_HostDB benchmark_HostDB.cc) target_link_libraries( benchmark_HostDB diff --git a/tools/benchmark/benchmark_ink_base64.cc b/tools/benchmark/benchmark_ink_base64.cc new file mode 100644 index 00000000000..8e620ee8206 --- /dev/null +++ b/tools/benchmark/benchmark_ink_base64.cc @@ -0,0 +1,137 @@ +/** @file + + Throughput benchmark for ats_base64_encode / ats_base64_decode comparing + the scalar path against the simdutf-backed path. + + Sizes bracket both the scalar↔SIMD crossover thresholds (24 bytes for + encode, 48 bytes for decode) and the typical caller sizes inside ATS + (8-byte SnowflakeID, 20-32 byte HMACs, ~200 byte OCSP DER requests, + larger payloads for ceiling measurements). Correctness is covered by + src/tscore/unit_tests/test_ink_base64.cc, which runs under ctest in + every build. + + Catch::Benchmark::keep_memory is used around each call to prevent the + optimizer from DCE-ing the inlined output buffer writes past the first + observed byte. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#define CATCH_CONFIG_ENABLE_BENCHMARKING + +#include +#include +#include + +#include "tscore/ink_base64.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + +// Sizes chosen to mirror real callers and to bracket the scalar↔SIMD +// crossover. +// 8B - SnowflakeID (uint64_t) +// 16-48B - HMAC-SHA1/SHA256 and crossover region for encode +// 64-96B - crossover region for decode +// 200B - typical OCSP DER request (RFC6960 caps at 255B encoded) +// 512B / 4KB - stress the inner loop where SIMD wins most +constexpr std::array kPayloadSizes{8, 16, 24, 32, 48, 64, 96, 200, 512, 4096}; + +std::vector +make_random_bytes(size_t n, uint64_t seed = 0xC0FFEEULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (size_t i = 0; i < n; ++i) { + v[i] = static_cast(rng() & 0xFFU); + } + return v; +} + +std::string +encode_with_ats(const std::vector &in) +{ + std::string out; + out.resize(ats_base64_encode_dstlen(in.size())); + size_t n = 0; + bool ok = ats_base64_encode(in.data(), in.size(), out.data(), out.size(), &n); + REQUIRE(ok); + out.resize(n); + return out; +} + +} // namespace + +TEST_CASE("active base64 configuration", "[base64][config]") +{ + // Print whether simdutf is wired in so the benchmark output makes the + // selected configuration obvious. + std::cout << "ats_base64 compiled with: "; +#if TS_USE_SIMDUTF + std::cout << "simdutf hybrid (scalar <= 24/48B, simdutf above)"; +#else + std::cout << "scalar only"; +#endif + std::cout << '\n'; + SUCCEED(); +} + +TEST_CASE("ats_base64_encode throughput", "[bench][base64][encode]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + std::vector output(ats_base64_encode_dstlen(sz) + 16); + + std::string name = "encode " + std::to_string(sz) + "B"; + BENCHMARK(name.c_str()) + { + size_t out_len = 0; + bool ok = ats_base64_encode(input.data(), input.size(), output.data(), output.size(), &out_len); + Catch::Benchmark::keep_memory(output.data()); + return ok ? out_len : size_t{0}; + }; + } +} + +TEST_CASE("ats_base64_decode throughput", "[bench][base64][decode]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + auto encoded = encode_with_ats(input); + std::vector output(ats_base64_decode_dstlen(encoded.size()) + 16); + + // Name reports the *plaintext* size so it lines up with the encode bench. + std::string name = "decode " + std::to_string(sz) + "B (" + std::to_string(encoded.size()) + "B b64)"; + BENCHMARK(name.c_str()) + { + size_t out_len = 0; + bool ok = ats_base64_decode(encoded.data(), encoded.size(), output.data(), output.size(), &out_len); + Catch::Benchmark::keep_memory(output.data()); + return ok ? out_len : size_t{0}; + }; + } +}