From 24c69fbb8f903fba1d8b135bb97a41e2d7f37578 Mon Sep 17 00:00:00 2001 From: Phong Nguyen Date: Thu, 14 May 2026 19:16:24 +0000 Subject: [PATCH 1/2] tscore: optional simdutf path for ats_base64 encode/decode The hand-rolled base64 implementation in ink_base64.cc is a measurable hotspot in places that encode or decode larger payloads (OCSP DER requests, S3 auth HMACs, signed URL segments). simdutf provides SIMD-accelerated kernels that run roughly an order of magnitude faster on medium-and-larger inputs on AVX2/AVX-512 hardware. Wire simdutf in as an opt-in dependency through the existing auto_option machinery (ENABLE_SIMDUTF, default AUTO). When the package is available, the wrapper dispatches to simdutf for inputs above an empirically chosen threshold and keeps the scalar path for smaller inputs, where simdutf's per-call overhead would otherwise be a regression (notably the 8-byte SnowflakeID encode). Both paths preserve the existing public contract: standard '+/=' encode alphabet, accepts both '+/' and '-_' on decode in the same call, tolerates missing padding, truncates silently on invalid input, and always writes a trailing NUL. A new microbenchmark under tools/benchmark locks the InkAPITest SDK_API_ENCODING fixture as a regression test and provides the throughput numbers used to choose the thresholds. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- CMakeLists.txt | 1 + include/tscore/ink_config.h.cmake.in | 1 + src/tscore/CMakeLists.txt | 4 + src/tscore/ink_base64.cc | 212 ++++++++++++++++-------- tools/benchmark/CMakeLists.txt | 3 + tools/benchmark/benchmark_ink_base64.cc | 205 +++++++++++++++++++++++ 6 files changed, 357 insertions(+), 69 deletions(-) create mode 100644 tools/benchmark/benchmark_ink_base64.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 49831bfebf9..333a86b137d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,7 @@ include(CheckOpenSSLIsBoringSSL) include(CheckOpenSSLIsQuictls) include(CheckOpenSSLIsAwsLc) find_package(OpenSSL REQUIRED) +auto_option(SIMDUTF FEATURE_VAR TS_USE_SIMDUTF PACKAGE_DEPENDS simdutf) check_openssl_is_boringssl(SSLLIB_IS_BORINGSSL BORINGSSL_VERSION "${OPENSSL_INCLUDE_DIR}") check_openssl_is_awslc(SSLLIB_IS_AWSLC AWSLC_VERSION "${OPENSSL_INCLUDE_DIR}") diff --git a/include/tscore/ink_config.h.cmake.in b/include/tscore/ink_config.h.cmake.in index 73c8b860fb9..988ed5cb8d0 100644 --- a/include/tscore/ink_config.h.cmake.in +++ b/include/tscore/ink_config.h.cmake.in @@ -163,6 +163,7 @@ const int DEFAULT_STACKSIZE = @DEFAULT_STACK_SIZE@; #cmakedefine01 TS_USE_POSIX_CAP #cmakedefine01 TS_USE_QUIC #cmakedefine01 TS_USE_REMOTE_UNWINDING +#cmakedefine01 TS_USE_SIMDUTF #cmakedefine01 TS_USE_TLS13 #cmakedefine01 TS_USE_TLS_ASYNC #cmakedefine01 TS_USE_TPROXY diff --git a/src/tscore/CMakeLists.txt b/src/tscore/CMakeLists.txt index 7790adc87dd..6fa2b55594c 100644 --- a/src/tscore/CMakeLists.txt +++ b/src/tscore/CMakeLists.txt @@ -110,6 +110,10 @@ target_link_libraries( tscore PUBLIC OpenSSL::Crypto libswoc::libswoc yaml-cpp::yaml-cpp systemtap::systemtap resolv::resolv ts::tsutil ) +if(TS_USE_SIMDUTF) + target_link_libraries(tscore PUBLIC simdutf::simdutf) +endif() + if(TS_USE_POSIX_CAP) target_link_libraries(tscore PUBLIC cap::cap) endif() diff --git a/src/tscore/ink_base64.cc b/src/tscore/ink_base64.cc index 
849d7c8ce83..bf581569496 100644 --- a/src/tscore/ink_base64.cc +++ b/src/tscore/ink_base64.cc @@ -1,6 +1,41 @@ /** @file - A brief file description + Base64 encoding and decoding. + + The public entry points (`ats_base64_encode` / `ats_base64_decode`, also + exposed through `TSBase64Encode` / `TSBase64Decode`) dispatch between two + internal implementations: + + - A hand-rolled scalar path, always present, used directly when + TS_USE_SIMDUTF is disabled, and used for inputs below the SIMD + crossover threshold when TS_USE_SIMDUTF is enabled. The scalar path + avoids simdutf's runtime ISA dispatch and virtual-call overhead, + which would otherwise dominate the cost for tiny inputs (e.g. the + 8-byte SnowflakeID encode). + + - simdutf, used for larger inputs when TS_USE_SIMDUTF is enabled. + simdutf provides SIMD-accelerated kernels and is several times + faster than the scalar path once the input is big enough to amortize + its per-call overhead. + + Thresholds were chosen empirically on a 2.1 GHz Broadwell-EP Xeon + (AVX2) using tools/benchmark/benchmark_ink_base64. The exact crossover + shifts on different cores but lies within an order of magnitude of these + values everywhere we've measured. + + Both paths preserve the same public contract: + + - encode: standard RFC 1521 alphabet (`+`, `/`), `=` padding, no line + breaks, trailing NUL written at outBuffer[length]. + - decode: accepts both standard (`+`, `/`) and URL-safe (`-`, `_`) + alphabets in the same input; tolerates missing padding; on an + invalid character, truncates and returns success with whatever was + decoded up to that point; trailing NUL written at outBuffer[length]. + + Note: simdutf's forgiving-base64 mode silently skips ASCII whitespace + (space, tab, CR, LF, FF) inside the input, whereas the scalar path + treats whitespace as an end-of-input marker. No caller in-tree feeds + whitespace to these functions. 
@section license License @@ -20,32 +55,50 @@ See the License for the specific language governing permissions and limitations under the License. */ - -/* - * Base64 encoding and decoding as according to RFC1521. Similar to uudecode. - * - * RFC 1521 requires inserting line breaks for long lines. The basic web - * authentication scheme does not require them. This implementation is - * intended for web-related use, and line breaks are not implemented. - * - */ #include "tscore/ink_platform.h" #include "tscore/ink_base64.h" #include "tscore/ink_assert.h" -// TODO: The code here seems a bit klunky, and could probably be improved a bit. +#if TS_USE_SIMDUTF +#include -bool -ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +// Inputs at or below these byte counts stay on the scalar path, where they +// outrun simdutf's per-call overhead. inBufferSize for encode is the binary +// plaintext length; for decode it is the base64-encoded length. +constexpr size_t BASE64_ENCODE_SIMD_THRESHOLD = 24; +constexpr size_t BASE64_DECODE_SIMD_THRESHOLD = 48; +#endif + +namespace +{ + +/* Converts a printable character to its six bit representation. 
*/ +const unsigned char printableToSixBit[256] = { + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, + 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, + 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; + +constexpr unsigned char MAX_PRINT_VAL = 63; + +inline unsigned char +decode_byte(char c) +{ + return printableToSixBit[static_cast(c)]; +} + +// Hand-rolled scalar encode. Caller has already validated outBufSize. +void +encode_scalar(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t *length) { static const char _codes[66] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; char *obuf = outBuffer; char in_tail[4]; - if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { - return false; - } - while (inBufferSize > 2) { *obuf++ = _codes[(inBuffer[0] >> 2) & 077]; *obuf++ = _codes[((inBuffer[0] & 03) << 4) | ((inBuffer[1] >> 4) & 017)]; @@ -56,14 +109,6 @@ ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outB inBuffer += 3; } - /* - * We've done all the input groups of three chars. We're left - * with 0, 1, or 2 input chars. 
We have to add zero-bits to the - * right if we don't have enough input chars. - * If 0 chars left, we're done. - * If 1 char left, form 2 output chars, and add 2 pad chars to output. - * If 2 chars left, form 3 output chars, add 1 pad char to output. - */ if (inBufferSize == 0) { *obuf = '\0'; if (length) { @@ -88,60 +133,26 @@ ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outB *length = (obuf + 4) - outBuffer; } } - - return true; -} - -bool -ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) -{ - return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); } -/*------------------------------------------------------------------------- - This is a reentrant, and malloc free implementation of ats_base64_decode. - -------------------------------------------------------------------------*/ -#ifdef DECODE -#undef DECODE -#endif - -#define DECODE(x) printableToSixBit[(unsigned char)x] -#define MAX_PRINT_VAL 63 - -/* Converts a printable character to it's six bit representation */ -const unsigned char printableToSixBit[256] = { - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, - 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, - 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 
64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; - -bool -ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t outBufSize, size_t *length) +// Hand-rolled scalar decode. Caller has already validated outBufSize. +void +decode_scalar(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t *length) { size_t inBytes = 0; size_t decodedBytes = 0; unsigned char *buf = outBuffer; int inputBytesDecoded = 0; - // Make sure there is sufficient space in the output buffer - if (outBufSize < ats_base64_decode_dstlen(inBufferSize)) { - return false; - } - // Ignore any trailing ='s or other undecodable characters. - // TODO: Perhaps that ought to be an error instead? - while (inBytes < inBufferSize && printableToSixBit[static_cast(inBuffer[inBytes])] <= MAX_PRINT_VAL) { + while (inBytes < inBufferSize && decode_byte(inBuffer[inBytes]) <= MAX_PRINT_VAL) { ++inBytes; } for (size_t i = 0; i < inBytes; i += 4) { - buf[0] = static_cast(DECODE(inBuffer[0]) << 2 | DECODE(inBuffer[1]) >> 4); - buf[1] = static_cast(DECODE(inBuffer[1]) << 4 | DECODE(inBuffer[2]) >> 2); - buf[2] = static_cast(DECODE(inBuffer[2]) << 6 | DECODE(inBuffer[3])); + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + buf[2] = static_cast(decode_byte(inBuffer[2]) << 6 | decode_byte(inBuffer[3])); buf += 3; inBuffer += 4; @@ -149,10 +160,10 @@ ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outB inputBytesDecoded += 4; } - // Check to see if we decoded a multiple of 4 four - // bytes + // If the consumed input wasn't a multiple of 4 we over-counted the last + // group; trim the trailing 1 or 2 bytes back off. 
if ((inBytes - inputBytesDecoded) & 0x3) { - if (DECODE(inBuffer[-2]) > MAX_PRINT_VAL) { + if (decode_byte(inBuffer[-2]) > MAX_PRINT_VAL) { decodedBytes -= 2; } else { decodedBytes -= 1; @@ -163,6 +174,69 @@ ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outB if (length) { *length = decodedBytes; } +} + +} // namespace + +bool +ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +{ + if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { + return false; + } + +#if TS_USE_SIMDUTF + if (inBufferSize > BASE64_ENCODE_SIMD_THRESHOLD) { + size_t written = simdutf::binary_to_base64(reinterpret_cast(inBuffer), inBufferSize, outBuffer); + outBuffer[written] = '\0'; + if (length) { + *length = written; + } + return true; + } +#endif + + encode_scalar(inBuffer, inBufferSize, outBuffer, length); + return true; +} + +bool +ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +{ + return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); +} + +bool +ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t outBufSize, size_t *length) +{ + if (outBufSize < ats_base64_decode_dstlen(inBufferSize)) { + return false; + } + +#if TS_USE_SIMDUTF + if (inBufferSize > BASE64_DECODE_SIMD_THRESHOLD) { + // Reserve one byte for the trailing NUL we always emit. + size_t out_len = outBufSize - 1; + auto r = simdutf::base64_to_binary_safe(inBuffer, inBufferSize, reinterpret_cast(outBuffer), out_len, + simdutf::base64_default_or_url, simdutf::last_chunk_handling_options::loose, + /*decode_up_to_bad_char=*/true); + + // OUTPUT_BUFFER_TOO_SMALL is impossible given the upfront dstlen check; + // be defensive anyway. INVALID_BASE64_CHARACTER is expected: scalar + // behavior truncated at bad chars without surfacing an error, so we do + // the same. 
+ if (r.error == simdutf::error_code::OUTPUT_BUFFER_TOO_SMALL) { + return false; + } + + outBuffer[out_len] = '\0'; + if (length) { + *length = out_len; + } + return true; + } +#endif + decode_scalar(inBuffer, inBufferSize, outBuffer, length); return true; } diff --git a/tools/benchmark/CMakeLists.txt b/tools/benchmark/CMakeLists.txt index 49f25fad1c1..d6fa658ca9e 100644 --- a/tools/benchmark/CMakeLists.txt +++ b/tools/benchmark/CMakeLists.txt @@ -36,6 +36,9 @@ target_link_libraries(benchmark_SharedMutex PRIVATE Catch2::Catch2 ts::tscore li add_executable(benchmark_Random benchmark_Random.cc) target_link_libraries(benchmark_Random PRIVATE Catch2::Catch2WithMain ts::tscore) +add_executable(benchmark_ink_base64 benchmark_ink_base64.cc) +target_link_libraries(benchmark_ink_base64 PRIVATE Catch2::Catch2WithMain ts::tscore) + add_executable(benchmark_HostDB benchmark_HostDB.cc) target_link_libraries( benchmark_HostDB diff --git a/tools/benchmark/benchmark_ink_base64.cc b/tools/benchmark/benchmark_ink_base64.cc new file mode 100644 index 00000000000..bec34456ab8 --- /dev/null +++ b/tools/benchmark/benchmark_ink_base64.cc @@ -0,0 +1,205 @@ +/** @file + + Micro benchmark for ats_base64_encode / ats_base64_decode and the bulk + scalar tolower path used by URL canonicalization. Establishes a baseline + prior to any SIMD work. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#define CATCH_CONFIG_ENABLE_BENCHMARKING + +#include +#include + +#include "tscore/ink_base64.h" +#include "tscore/ParseRules.h" + +#include +#include +#include +#include +#include +#include + +namespace +{ + +// Sizes chosen to mirror real callers and to bracket the scalar↔SIMD +// crossover. +// 8B - SnowflakeID (uint64_t) +// 16-48B - HMAC-SHA1/SHA256 and crossover region for encode +// 64-128B - crossover region for decode +// 200B - typical OCSP DER request (RFC6960 caps at 255B encoded) +// 512B / 4KB - stress the inner loop where SIMD wins most +constexpr std::array kPayloadSizes{8, 16, 24, 32, 48, 64, 96, 200, 512, 4096}; + +std::vector +make_random_bytes(size_t n, uint64_t seed = 0xC0FFEEULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (size_t i = 0; i < n; ++i) { + v[i] = static_cast(rng() & 0xFFU); + } + return v; +} + +std::string +encode_with_ats(const std::vector &in) +{ + std::string out; + out.resize(ats_base64_encode_dstlen(in.size())); + size_t n = 0; + bool ok = ats_base64_encode(in.data(), in.size(), out.data(), out.size(), &n); + REQUIRE(ok); + out.resize(n); + return out; +} + +std::vector +make_mixed_case_ascii(size_t n, uint64_t seed = 0xABCDEFULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (size_t i = 0; i < n; ++i) { + // Mix of uppercase, lowercase, and a few non-letter bytes that should + // pass through tolower unchanged. Models a URL/header byte stream. 
+ auto r = static_cast(rng() & 0x3FU); + if (r < 26U) { + v[i] = static_cast('A' + r); + } else if (r < 52U) { + v[i] = static_cast('a' + (r - 26U)); + } else { + static constexpr char kNonAlpha[] = "0123456789-_./:"; + v[i] = kNonAlpha[r % (sizeof(kNonAlpha) - 1U)]; + } + } + return v; +} + +// Equivalent of the static inline memcpy_tolower() in src/proxy/hdrs/URL.cc. +// Reproduced here because that definition has internal linkage and isn't +// reachable from this TU. +inline void +memcpy_tolower_scalar(char *d, const char *s, int n) +{ + while (n--) { + *d = ParseRules::ink_tolower(*s); + ++s; + ++d; + } +} + +} // namespace + +TEST_CASE("ats_base64 round-trip correctness", "[base64][correctness]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + auto encoded = encode_with_ats(input); + std::vector decoded(ats_base64_decode_dstlen(encoded.size()) + 1); + size_t dec_len = 0; + REQUIRE(ats_base64_decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &dec_len)); + REQUIRE(dec_len == sz); + REQUIRE(std::memcmp(decoded.data(), input.data(), sz) == 0); + } +} + +// Lock the same byte-exact fixture used by InkAPITest's SDK_API_ENCODING +// regression test. Any future implementation swap must keep this passing. 
+TEST_CASE("ats_base64 InkAPITest fixture", "[base64][correctness][fixture]") +{ + const char *url = "http://www.example.com/foo?fie= \"#%<>[]\\^`{}~&bar={test}&fum=Apache Traffic Server"; + const char *url_b64 = + "aHR0cDovL3d3dy5leGFtcGxlLmNvbS9mb28/ZmllPSAiIyU8PltdXF5ge31+JmJhcj17dGVzdH0mZnVtPUFwYWNoZSBUcmFmZmljIFNlcnZlcg=="; + const auto url_len = std::strlen(url); + const auto url_b64_len = std::strlen(url_b64); + + SECTION("encode produces byte-identical RFC1521 output with '=' padding") + { + std::array buf{}; + size_t enc_len = 0; + REQUIRE(ats_base64_encode(url, url_len, buf.data(), buf.size(), &enc_len)); + REQUIRE(enc_len == url_b64_len); + REQUIRE(std::strcmp(buf.data(), url_b64) == 0); + } + + SECTION("decode reproduces the original byte-for-byte") + { + std::array buf{}; + size_t dec_len = 0; + REQUIRE(ats_base64_decode(url_b64, url_b64_len, reinterpret_cast(buf.data()), buf.size(), &dec_len)); + REQUIRE(dec_len == url_len); + REQUIRE(std::strcmp(buf.data(), url) == 0); + } +} + +TEST_CASE("ats_base64_encode throughput", "[bench][base64][encode]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + std::vector output(ats_base64_encode_dstlen(sz) + 16); + + std::string name = "encode " + std::to_string(sz) + "B"; + BENCHMARK(name.c_str()) + { + size_t out_len = 0; + bool ok = ats_base64_encode(input.data(), input.size(), output.data(), output.size(), &out_len); + // Return a value that depends on the work to prevent DCE. + return ok ? out_len : size_t{0}; + }; + } +} + +TEST_CASE("ats_base64_decode throughput", "[bench][base64][decode]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + auto encoded = encode_with_ats(input); + std::vector output(ats_base64_decode_dstlen(encoded.size()) + 16); + + // Name reports the *plaintext* size so it lines up with the encode bench. 
+ std::string name = "decode " + std::to_string(sz) + "B (" + std::to_string(encoded.size()) + "B b64)"; + BENCHMARK(name.c_str()) + { + size_t out_len = 0; + bool ok = ats_base64_decode(encoded.data(), encoded.size(), output.data(), output.size(), &out_len); + return ok ? out_len : size_t{0}; + }; + } +} + +TEST_CASE("memcpy_tolower throughput", "[bench][tolower]") +{ + // Sizes chosen to model URL paths / header names / cache-key segments. + constexpr std::array kTolowerSizes{16, 64, 256, 1024}; + + for (size_t sz : kTolowerSizes) { + auto input = make_mixed_case_ascii(sz); + std::vector output(sz); + + std::string name = "tolower " + std::to_string(sz) + "B"; + BENCHMARK(name.c_str()) + { + memcpy_tolower_scalar(output.data(), input.data(), static_cast(sz)); + return output[0]; + }; + } +} From 0b36457fe2685b996df086033ecdfef883a9cd54 Mon Sep 17 00:00:00 2001 From: Phong Nguyen Date: Thu, 21 May 2026 21:43:58 +0000 Subject: [PATCH 2/2] tscore/base64: address #13166 review feedback - CMakeLists.txt: require simdutf >= 7.0.0. ats_base64_decode uses base64_default_or_url and the decode_up_to_bad_char parameter, both of which landed in simdutf 7.0.0. Without this pin, an older simdutf passes find_package and then fails at compile time. ENABLE_SIMDUTF in AUTO mode silently falls back to the scalar path when the found simdutf is too old; ENABLE_SIMDUTF=ON hard-errors so the user knows their explicit request cannot be satisfied (Copilot). - ink_base64.cc: align the simdutf and scalar decode paths on whitespace. simdutf's forgiving mode silently skips ASCII whitespace and continues; the scalar treats whitespace as end-of-input. With the two paths gated by an input-size threshold, this made TSBase64Decode results depend on build configuration. Pre-scan the input with the same printableToSixBit table upfront and truncate inBufferSize at the first non-alphabet byte before either path runs, so both see the same prefix of alphabet bytes (Copilot). 
- ink_base64.cc: restructure the scalar decode tail. The previous code ran one extra loop iteration past the alphabet prefix when there were 1..3 trailing alphabet bytes (reading inBuffer[2..3] which was either OOB to the caller or past the prefix) and then read inBuffer[-2] in the trailing adjustment block when no iterations had advanced inBuffer. Process only complete 4-character groups in the main loop and decode any 2- or 3-byte tail explicitly; a 1-byte tail encodes nothing meaningful and is dropped. This was flagged as a known follow-up when the PR landed. - src/tscore/unit_tests/test_ink_base64.cc: new unit test under test_tscore so the scalar and simdutf paths are covered by ctest in every build. Bracketing sizes 0/1/8/23/24/25/47/48/49/4096 exercise both implementations and the threshold transitions. Adds focused cases for URL-safe alphabet decode, in-place decode (dst == src), invalid-byte truncation, whitespace truncation (validates the new alignment), the InkAPITest fixture, and the 1-/2-/3-char tail cases that the scalar restructure now handles cleanly (Copilot). - tools/benchmark/benchmark_ink_base64.cc: rewrite the file header to describe what the bench actually does (scalar-vs-simdutf throughput comparison) and drop the correctness TEST_CASEs that moved to the unit test. Add Catch::Benchmark::keep_memory barriers so the inlined buffer writes aren't DCEd past the first observed byte, and a config-print case that prints whether simdutf is wired in (Copilot). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- CMakeLists.txt | 19 ++ src/tscore/CMakeLists.txt | 1 + src/tscore/ink_base64.cc | 115 +++++++---- src/tscore/unit_tests/test_ink_base64.cc | 242 +++++++++++++++++++++++ tools/benchmark/benchmark_ink_base64.cc | 126 +++--------- 5 files changed, 363 insertions(+), 140 deletions(-) create mode 100644 src/tscore/unit_tests/test_ink_base64.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 333a86b137d..bf603615068 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -282,6 +282,25 @@ include(CheckOpenSSLIsQuictls) include(CheckOpenSSLIsAwsLc) find_package(OpenSSL REQUIRED) auto_option(SIMDUTF FEATURE_VAR TS_USE_SIMDUTF PACKAGE_DEPENDS simdutf) +# ats_base64_decode uses base64_default_or_url (added in 7.0.0) and the +# decode_up_to_bad_char parameter on base64_to_binary_safe (also 7.0.0). A +# simdutf package older than that would pass find_package and then fail at +# compile time. Disable the feature with a status message in AUTO, hard-error +# if the user asked for it explicitly. +set(_MIN_SIMDUTF_VERSION "7.0.0") +if(TS_USE_SIMDUTF AND simdutf_VERSION VERSION_LESS "${_MIN_SIMDUTF_VERSION}") + if(ENABLE_SIMDUTF STREQUAL "AUTO") + message( + STATUS + "Found simdutf ${simdutf_VERSION} but >= ${_MIN_SIMDUTF_VERSION} required for ats_base64; falling back to scalar" + ) + set(TS_USE_SIMDUTF FALSE) + else() + message(FATAL_ERROR "TS_USE_SIMDUTF requires simdutf >= ${_MIN_SIMDUTF_VERSION} (found ${simdutf_VERSION}). " + "Either upgrade simdutf or configure with -DENABLE_SIMDUTF=OFF." 
+ ) + endif() +endif() check_openssl_is_boringssl(SSLLIB_IS_BORINGSSL BORINGSSL_VERSION "${OPENSSL_INCLUDE_DIR}") check_openssl_is_awslc(SSLLIB_IS_AWSLC AWSLC_VERSION "${OPENSSL_INCLUDE_DIR}") diff --git a/src/tscore/CMakeLists.txt b/src/tscore/CMakeLists.txt index 6fa2b55594c..effccd0647c 100644 --- a/src/tscore/CMakeLists.txt +++ b/src/tscore/CMakeLists.txt @@ -162,6 +162,7 @@ if(BUILD_TESTING) unit_tests/test_Throttler.cc unit_tests/test_Tokenizer.cc unit_tests/test_arena.cc + unit_tests/test_ink_base64.cc unit_tests/test_ink_inet.cc unit_tests/test_ink_memory.cc unit_tests/test_ink_string.cc diff --git a/src/tscore/ink_base64.cc b/src/tscore/ink_base64.cc index bf581569496..1a1bab12afb 100644 --- a/src/tscore/ink_base64.cc +++ b/src/tscore/ink_base64.cc @@ -27,15 +27,21 @@ - encode: standard RFC 1521 alphabet (`+`, `/`), `=` padding, no line breaks, trailing NUL written at outBuffer[length]. - - decode: accepts both standard (`+`, `/`) and URL-safe (`-`, `_`) - alphabets in the same input; tolerates missing padding; on an - invalid character, truncates and returns success with whatever was - decoded up to that point; trailing NUL written at outBuffer[length]. - Note: simdutf's forgiving-base64 mode silently skips ASCII whitespace - (space, tab, CR, LF, FF) inside the input, whereas the scalar path - treats whitespace as an end-of-input marker. No caller in-tree feeds - whitespace to these functions. + - decode: accepts both standard (`+`, `/`) and URL-safe (`-`, `_`) + alphabets in the same input; tolerates missing padding; on any + non-alphabet byte (including ASCII whitespace, '=', or garbage), + truncates and returns success with whatever was decoded up to that + point; trailing NUL written at outBuffer[length]; supports in-place + decode (dst == src). + + Decode whitespace alignment: simdutf's forgiving-base64 mode would + silently skip ASCII whitespace and continue. 
To keep TSBase64Decode + results independent of build configuration and input size, the wrapper + pre-scans the input with the same printableToSixBit table the scalar + path uses and truncates inBufferSize at the first non-alphabet byte + before handing it to either implementation. Both paths therefore see + the same prefix of valid alphabet bytes and produce identical output. @section license License @@ -91,6 +97,19 @@ decode_byte(char c) return printableToSixBit[static_cast(c)]; } +// Count the leading base64-alphabet bytes (standard or URL-safe). The result +// is the prefix length that both decode paths actually consume; any byte at +// or after this index is whitespace, '=', or garbage and is dropped. +inline size_t +count_alphabet_prefix(const char *inBuffer, size_t inBufferSize) +{ + size_t valid = 0; + while (valid < inBufferSize && decode_byte(inBuffer[valid]) <= MAX_PRINT_VAL) { + ++valid; + } + return valid; +} + // Hand-rolled scalar encode. Caller has already validated outBufSize. void encode_scalar(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t *length) @@ -135,42 +154,44 @@ encode_scalar(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffe } } -// Hand-rolled scalar decode. Caller has already validated outBufSize. +// Hand-rolled scalar decode. Caller has pre-scanned: every byte in +// inBuffer[0..inBufferSize) is in the base64 alphabet (decode_byte() returns +// <= 63). The caller has also validated outBufSize. +// +// This restructures the legacy decode tail handling. The previous code ran +// one extra loop iteration past the alphabet prefix when inBufferSize was in +// {1, 2, 3} (reading inBuffer[2..3] which was either OOB to the caller's +// buffer or past the valid prefix) and then read inBuffer[-2] in the trailing +// adjustment block when no loop iterations had advanced inBuffer. 
Process +// only complete 4-character groups in the main loop and decode any 2- or +// 3-byte tail explicitly; a 1-byte tail encodes nothing meaningful and is +// dropped, matching what an RFC 4648 decoder is supposed to do. void decode_scalar(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t *length) { - size_t inBytes = 0; - size_t decodedBytes = 0; - unsigned char *buf = outBuffer; - int inputBytesDecoded = 0; - - // Ignore any trailing ='s or other undecodable characters. - while (inBytes < inBufferSize && decode_byte(inBuffer[inBytes]) <= MAX_PRINT_VAL) { - ++inBytes; + size_t decodedBytes = 0; + unsigned char *buf = outBuffer; + + while (inBufferSize >= 4) { + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + buf[2] = static_cast(decode_byte(inBuffer[2]) << 6 | decode_byte(inBuffer[3])); + buf += 3; + inBuffer += 4; + decodedBytes += 3; + inBufferSize -= 4; } - for (size_t i = 0; i < inBytes; i += 4) { - buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); - buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); - buf[2] = static_cast(decode_byte(inBuffer[2]) << 6 | decode_byte(inBuffer[3])); - - buf += 3; - inBuffer += 4; - decodedBytes += 3; - inputBytesDecoded += 4; - } - - // If the consumed input wasn't a multiple of 4 we over-counted the last - // group; trim the trailing 1 or 2 bytes back off. 
- if ((inBytes - inputBytesDecoded) & 0x3) { - if (decode_byte(inBuffer[-2]) > MAX_PRINT_VAL) { - decodedBytes -= 2; - } else { - decodedBytes -= 1; + if (inBufferSize >= 2) { + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + decodedBytes += 1; + if (inBufferSize >= 3) { + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + decodedBytes += 1; } } - outBuffer[decodedBytes] = '\0'; + outBuffer[decodedBytes] = '\0'; if (length) { *length = decodedBytes; } @@ -213,18 +234,26 @@ ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outB return false; } + // Truncate to the leading base64-alphabet prefix. Doing this upfront for + // both paths is what keeps the SIMD and scalar decoders aligned on inputs + // that contain ASCII whitespace, '=' padding, or any other non-alphabet + // byte; otherwise simdutf's forgiving mode would skip whitespace and + // continue while the scalar would have stopped at it. + const size_t valid = count_alphabet_prefix(inBuffer, inBufferSize); + #if TS_USE_SIMDUTF - if (inBufferSize > BASE64_DECODE_SIMD_THRESHOLD) { - // Reserve one byte for the trailing NUL we always emit. + if (valid > BASE64_DECODE_SIMD_THRESHOLD) { + // Reserve one byte for the trailing NUL we always emit. The input we + // pass to simdutf is pure alphabet bytes (no whitespace, no '='), so + // last_chunk_handling_options::loose handles the unpadded tail and + // decode_up_to_bad_char never triggers in practice. size_t out_len = outBufSize - 1; - auto r = simdutf::base64_to_binary_safe(inBuffer, inBufferSize, reinterpret_cast(outBuffer), out_len, + auto r = simdutf::base64_to_binary_safe(inBuffer, valid, reinterpret_cast(outBuffer), out_len, simdutf::base64_default_or_url, simdutf::last_chunk_handling_options::loose, /*decode_up_to_bad_char=*/true); // OUTPUT_BUFFER_TOO_SMALL is impossible given the upfront dstlen check; - // be defensive anyway. 
INVALID_BASE64_CHARACTER is expected: scalar - // behavior truncated at bad chars without surfacing an error, so we do - // the same. + // be defensive anyway. if (r.error == simdutf::error_code::OUTPUT_BUFFER_TOO_SMALL) { return false; } @@ -237,6 +266,6 @@ ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outB } #endif - decode_scalar(inBuffer, inBufferSize, outBuffer, length); + decode_scalar(inBuffer, valid, outBuffer, length); return true; } diff --git a/src/tscore/unit_tests/test_ink_base64.cc b/src/tscore/unit_tests/test_ink_base64.cc new file mode 100644 index 00000000000..478edc9b928 --- /dev/null +++ b/src/tscore/unit_tests/test_ink_base64.cc @@ -0,0 +1,242 @@ +/** @file + + Unit tests for ats_base64_encode / ats_base64_decode. + + Runs as part of the standard test_tscore binary so the scalar and + simdutf decode paths are exercised by ctest in every build, not just + when ENABLE_BENCHMARKS is set. The scenarios bracket the SIMD + crossover thresholds (24 bytes for encode, 48 bytes for decode) so + that any future divergence between the two implementations is caught. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +#include + +#include "tscore/ink_base64.h" + +#include +#include +#include +#include +#include +#include + +namespace +{ + +std::vector +make_random_bytes(std::size_t n, std::uint64_t seed = 0xC0FFEEULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (std::size_t i = 0; i < n; ++i) { + v[i] = static_cast(rng() & 0xFFU); + } + return v; +} + +std::string +encode_with_ats(const std::vector &in) +{ + std::string out; + out.resize(ats_base64_encode_dstlen(in.size())); + std::size_t n = 0; + bool ok = ats_base64_encode(in.data(), in.size(), out.data(), out.size(), &n); + REQUIRE(ok); + out.resize(n); + return out; +} + +} // namespace + +// Random round-trip across sizes that bracket both the encode threshold (24) +// and the decode threshold (48), so the scalar and simdutf paths are both +// exercised every run. +TEST_CASE("ats_base64 round-trip across SIMD thresholds", "[ats_base64]") +{ + for (std::size_t sz : std::array{0, 1, 8, 23, 24, 25, 47, 48, 49, 4096}) { + auto input = make_random_bytes(sz, 0xC0FFEE + sz); + auto encoded = encode_with_ats(input); + std::vector decoded(ats_base64_decode_dstlen(encoded.size()) + 1); + std::size_t dec_len = 0; + + CAPTURE(sz); + REQUIRE(ats_base64_decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &dec_len)); + REQUIRE(dec_len == sz); + REQUIRE(std::memcmp(decoded.data(), input.data(), sz) == 0); + } +} + +// Byte-exact fixture taken from InkAPITest's SDK_API_ENCODING regression +// test. Any future implementation swap must keep this passing. 
+TEST_CASE("ats_base64 InkAPITest fixture", "[ats_base64][fixture]") +{ + const char *url = "http://www.example.com/foo?fie= \"#%<>[]\\^`{}~&bar={test}&fum=Apache Traffic Server"; + const char *url_b64 = + "aHR0cDovL3d3dy5leGFtcGxlLmNvbS9mb28/ZmllPSAiIyU8PltdXF5ge31+JmJhcj17dGVzdH0mZnVtPUFwYWNoZSBUcmFmZmljIFNlcnZlcg=="; + const auto url_len = std::strlen(url); + const auto url_b64_len = std::strlen(url_b64); + + SECTION("encode produces byte-identical RFC1521 output with '=' padding") + { + std::array buf{}; + std::size_t enc_len = 0; + REQUIRE(ats_base64_encode(url, url_len, buf.data(), buf.size(), &enc_len)); + REQUIRE(enc_len == url_b64_len); + REQUIRE(std::strcmp(buf.data(), url_b64) == 0); + } + + SECTION("decode reproduces the original byte-for-byte") + { + std::array buf{}; + std::size_t dec_len = 0; + REQUIRE(ats_base64_decode(url_b64, url_b64_len, reinterpret_cast(buf.data()), buf.size(), &dec_len)); + REQUIRE(dec_len == url_len); + REQUIRE(std::strcmp(buf.data(), url) == 0); + } +} + +// The decoder accepts the URL-safe alphabet ('-' for '+', '_' for '/') in +// the same call as standard input. Long enough to exercise the simdutf path. +TEST_CASE("ats_base64_decode accepts URL-safe alphabet", "[ats_base64]") +{ + // Decode the same payload twice, once with standard '+/' and once with + // URL-safe '-_', and require identical output. The 0xfb/0xbf/0xff pattern + // produces '+' and '/' in the encoded form, so the URL-safe substitution + // actually changes bytes. 
+ const std::vector payload = { + 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, + 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, + 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, 0xfb, 0xff, 0xbf, 0xff, 0xfe, 0xbf, + }; + + std::string standard = encode_with_ats(payload); + REQUIRE(standard.size() > 48); // encoded form must cross BASE64_DECODE_SIMD_THRESHOLD + std::string url_safe = standard; + for (auto &c : url_safe) { + if (c == '+') { + c = '-'; + } else if (c == '/') { + c = '_'; + } + } + REQUIRE(url_safe != standard); // payload chosen so the swap actually changes bytes + + std::vector out_std(ats_base64_decode_dstlen(standard.size()) + 1); + std::vector out_url(ats_base64_decode_dstlen(url_safe.size()) + 1); + std::size_t len_std = 0; + std::size_t len_url = 0; + REQUIRE(ats_base64_decode(standard.data(), standard.size(), out_std.data(), out_std.size(), &len_std)); + REQUIRE(ats_base64_decode(url_safe.data(), url_safe.size(), out_url.data(), out_url.size(), &len_url)); + + REQUIRE(len_std == payload.size()); + REQUIRE(len_url == payload.size()); + REQUIRE(std::memcmp(out_std.data(), payload.data(), payload.size()) == 0); + REQUIRE(std::memcmp(out_url.data(), payload.data(), payload.size()) == 0); +} + +// In-place decode (dst == src) must produce the same result as decoding into +// a separate buffer. Used by plugins/experimental/magick. +TEST_CASE("ats_base64_decode supports in-place (dst == src)", "[ats_base64]") +{ + for (std::size_t sz : std::array{1, 16, 24, 47, 48, 200}) { + auto input = make_random_bytes(sz, 0xBADF00D + sz); + std::string encoded = encode_with_ats(input); + const std::size_t enc_size = encoded.size(); + // The in-place buffer must hold both the encoded input AND the trailing + // NUL the decoder writes; encoded.size() is always >= the decoded size, + // so one extra byte for the NUL is enough. 
+ std::string in_place = encoded; + in_place.resize(enc_size + 1); + + std::vector reference(ats_base64_decode_dstlen(encoded.size()) + 1); + std::size_t ref_len = 0; + std::size_t ip_len = 0; + + REQUIRE(ats_base64_decode(encoded.data(), enc_size, reference.data(), reference.size(), &ref_len)); + REQUIRE( + ats_base64_decode(in_place.data(), enc_size, reinterpret_cast(in_place.data()), in_place.size(), &ip_len)); + + CAPTURE(sz); + REQUIRE(ip_len == ref_len); + REQUIRE(std::memcmp(in_place.data(), reference.data(), ref_len) == 0); + } +} + +// A non-alphabet byte mid-input truncates: the decoder must stop at the +// first such byte and return the bytes decoded up to that point. This was +// the documented behavior of the legacy scalar path and the simdutf wrapper +// preserves it by pre-scanning the input. +TEST_CASE("ats_base64_decode truncates at first non-alphabet byte", "[ats_base64]") +{ + // 28 alphabet bytes, then three '!' stop bytes, then more alphabet (we + // expect the tail to be ignored). 28 chars decode to 21 bytes. The valid + // prefix is below the simdutf threshold, so this exercises the scalar path. + const char *input = "AAAAAAAAAAAAAAAAAAAAAAAAAAAA!!!IGNORED-TAIL-IGNORED-TAIL"; + + std::array out{}; + std::size_t len = 0; + REQUIRE(ats_base64_decode(input, std::strlen(input), out.data(), out.size(), &len)); + REQUIRE(len == 21); + for (std::size_t i = 0; i < len; ++i) { + REQUIRE(out[i] == 0); // 'A' = base64 index 0, so 28 'A's decode to 21 zero bytes + } +} + +// Whitespace in the input should be treated like any other non-alphabet +// byte: the decoder stops at it. This is the property that keeps the SIMD +// and scalar paths aligned regardless of input length, since simdutf would +// otherwise silently skip whitespace and continue. +TEST_CASE("ats_base64_decode stops at ASCII whitespace", "[ats_base64]") +{ + // The raw input (60 bytes) is longer than the simdutf threshold, but the + // tab at index 40 leaves a 40-byte valid prefix, which decodes on the scalar path. 
+ std::string input; + input.assign(60, 'A'); // 60 'A's -> 45 zero bytes if fully decoded + input[40] = '\t'; // first whitespace at index 40 -> 30 bytes after truncation + + std::array out{}; + std::size_t len = 0; + REQUIRE(ats_base64_decode(input.data(), input.size(), out.data(), out.size(), &len)); + REQUIRE(len == 30); +} + +// 1, 2, and 3 base64 chars decode to 0, 1, and 2 bytes respectively. Previous +// code had OOB reads in this path; this case guards against a regression. +TEST_CASE("ats_base64_decode handles very short alphabet inputs", "[ats_base64]") +{ + std::array out{}; + std::size_t len = 0; + + // 1 alphabet char: encodes nothing meaningful, decoded length is 0. + REQUIRE(ats_base64_decode("A", 1, out.data(), out.size(), &len)); + REQUIRE(len == 0); + + // 2 alphabet chars: decoded length is 1. + REQUIRE(ats_base64_decode("AA", 2, out.data(), out.size(), &len)); + REQUIRE(len == 1); + REQUIRE(out[0] == 0); + + // 3 alphabet chars: decoded length is 2. + REQUIRE(ats_base64_decode("AAA", 3, out.data(), out.size(), &len)); + REQUIRE(len == 2); + REQUIRE(out[0] == 0); + REQUIRE(out[1] == 0); +} diff --git a/tools/benchmark/benchmark_ink_base64.cc b/tools/benchmark/benchmark_ink_base64.cc index bec34456ab8..8e620ee8206 100644 --- a/tools/benchmark/benchmark_ink_base64.cc +++ b/tools/benchmark/benchmark_ink_base64.cc @@ -1,8 +1,18 @@ /** @file - Micro benchmark for ats_base64_encode / ats_base64_decode and the bulk - scalar tolower path used by URL canonicalization. Establishes a baseline - prior to any SIMD work. + Throughput benchmark for ats_base64_encode / ats_base64_decode comparing + the scalar path against the simdutf-backed path. + + Sizes bracket both the scalar↔SIMD crossover thresholds (24 bytes for + encode, 48 bytes for decode) and the typical caller sizes inside ATS + (8-byte SnowflakeID, 20-32 byte HMACs, ~200 byte OCSP DER requests, + larger payloads for ceiling measurements). 
Correctness is covered by + src/tscore/unit_tests/test_ink_base64.cc, which runs under ctest in + every build. + + Catch::Benchmark::keep_memory is used around each call to prevent the + optimizer from DCE-ing the inlined output buffer writes past the first + observed byte. @section license License @@ -27,13 +37,14 @@ #include #include +#include #include "tscore/ink_base64.h" -#include "tscore/ParseRules.h" #include #include #include +#include #include #include #include @@ -45,7 +56,7 @@ namespace // crossover. // 8B - SnowflakeID (uint64_t) // 16-48B - HMAC-SHA1/SHA256 and crossover region for encode -// 64-128B - crossover region for decode +// 64-96B - crossover region for decode // 200B - typical OCSP DER request (RFC6960 caps at 255B encoded) // 512B / 4KB - stress the inner loop where SIMD wins most constexpr std::array kPayloadSizes{8, 16, 24, 32, 48, 64, 96, 200, 512, 4096}; @@ -73,82 +84,20 @@ encode_with_ats(const std::vector &in) return out; } -std::vector -make_mixed_case_ascii(size_t n, uint64_t seed = 0xABCDEFULL) -{ - std::mt19937_64 rng(seed); - std::vector v(n); - for (size_t i = 0; i < n; ++i) { - // Mix of uppercase, lowercase, and a few non-letter bytes that should - // pass through tolower unchanged. Models a URL/header byte stream. - auto r = static_cast(rng() & 0x3FU); - if (r < 26U) { - v[i] = static_cast('A' + r); - } else if (r < 52U) { - v[i] = static_cast('a' + (r - 26U)); - } else { - static constexpr char kNonAlpha[] = "0123456789-_./:"; - v[i] = kNonAlpha[r % (sizeof(kNonAlpha) - 1U)]; - } - } - return v; -} - -// Equivalent of the static inline memcpy_tolower() in src/proxy/hdrs/URL.cc. -// Reproduced here because that definition has internal linkage and isn't -// reachable from this TU. 
-inline void -memcpy_tolower_scalar(char *d, const char *s, int n) -{ - while (n--) { - *d = ParseRules::ink_tolower(*s); - ++s; - ++d; - } -} - } // namespace -TEST_CASE("ats_base64 round-trip correctness", "[base64][correctness]") +TEST_CASE("active base64 configuration", "[base64][config]") { - for (size_t sz : kPayloadSizes) { - auto input = make_random_bytes(sz); - auto encoded = encode_with_ats(input); - std::vector decoded(ats_base64_decode_dstlen(encoded.size()) + 1); - size_t dec_len = 0; - REQUIRE(ats_base64_decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &dec_len)); - REQUIRE(dec_len == sz); - REQUIRE(std::memcmp(decoded.data(), input.data(), sz) == 0); - } -} - -// Lock the same byte-exact fixture used by InkAPITest's SDK_API_ENCODING -// regression test. Any future implementation swap must keep this passing. -TEST_CASE("ats_base64 InkAPITest fixture", "[base64][correctness][fixture]") -{ - const char *url = "http://www.example.com/foo?fie= \"#%<>[]\\^`{}~&bar={test}&fum=Apache Traffic Server"; - const char *url_b64 = - "aHR0cDovL3d3dy5leGFtcGxlLmNvbS9mb28/ZmllPSAiIyU8PltdXF5ge31+JmJhcj17dGVzdH0mZnVtPUFwYWNoZSBUcmFmZmljIFNlcnZlcg=="; - const auto url_len = std::strlen(url); - const auto url_b64_len = std::strlen(url_b64); - - SECTION("encode produces byte-identical RFC1521 output with '=' padding") - { - std::array buf{}; - size_t enc_len = 0; - REQUIRE(ats_base64_encode(url, url_len, buf.data(), buf.size(), &enc_len)); - REQUIRE(enc_len == url_b64_len); - REQUIRE(std::strcmp(buf.data(), url_b64) == 0); - } - - SECTION("decode reproduces the original byte-for-byte") - { - std::array buf{}; - size_t dec_len = 0; - REQUIRE(ats_base64_decode(url_b64, url_b64_len, reinterpret_cast(buf.data()), buf.size(), &dec_len)); - REQUIRE(dec_len == url_len); - REQUIRE(std::strcmp(buf.data(), url) == 0); - } + // Print whether simdutf is wired in so the benchmark output makes the + // selected configuration obvious. 
+ std::cout << "ats_base64 compiled with: "; +#if TS_USE_SIMDUTF + std::cout << "simdutf hybrid (scalar <= 24/48B, simdutf above)"; +#else + std::cout << "scalar only"; +#endif + std::cout << '\n'; + SUCCEED(); } TEST_CASE("ats_base64_encode throughput", "[bench][base64][encode]") @@ -162,7 +111,7 @@ TEST_CASE("ats_base64_encode throughput", "[bench][base64][encode]") { size_t out_len = 0; bool ok = ats_base64_encode(input.data(), input.size(), output.data(), output.size(), &out_len); - // Return a value that depends on the work to prevent DCE. + Catch::Benchmark::keep_memory(output.data()); return ok ? out_len : size_t{0}; }; } @@ -181,25 +130,8 @@ TEST_CASE("ats_base64_decode throughput", "[bench][base64][decode]") { size_t out_len = 0; bool ok = ats_base64_decode(encoded.data(), encoded.size(), output.data(), output.size(), &out_len); + Catch::Benchmark::keep_memory(output.data()); return ok ? out_len : size_t{0}; }; } } - -TEST_CASE("memcpy_tolower throughput", "[bench][tolower]") -{ - // Sizes chosen to model URL paths / header names / cache-key segments. - constexpr std::array kTolowerSizes{16, 64, 256, 1024}; - - for (size_t sz : kTolowerSizes) { - auto input = make_mixed_case_ascii(sz); - std::vector output(sz); - - std::string name = "tolower " + std::to_string(sz) + "B"; - BENCHMARK(name.c_str()) - { - memcpy_tolower_scalar(output.data(), input.data(), static_cast(sz)); - return output[0]; - }; - } -}