From 09ddbf4373d9db2b77575bccb040afefe060f081 Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Thu, 9 Apr 2026 23:30:27 +0000 Subject: [PATCH 01/11] Tested and benchmarked OneDNN BRGeMM integration against dev branch --- BUILD.bazel | 30 ++- MODULE.bazel | 12 ++ bazel/onednn.BUILD | 228 ++++++++++++++++++++ ops/bench_matmul.cc | 6 +- ops/brgemm-inl.h | 492 ++++++++++++++++++++++++++++++++++++++++++++ ops/brgemm.h | 297 ++++++++++++++++++++++++++ ops/matmul-inl.h | 44 ++++ ops/matmul.h | 4 + util/zones.cc | 2 + util/zones.h | 1 + 10 files changed, 1110 insertions(+), 6 deletions(-) create mode 100644 bazel/onednn.BUILD create mode 100644 ops/brgemm-inl.h create mode 100644 ops/brgemm.h diff --git a/BUILD.bazel b/BUILD.bazel index deb376bf..704bd391 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -313,7 +313,14 @@ test_suite( cc_library( name = "matmul_env", srcs = ["ops/matmul.cc"], - hdrs = ["ops/matmul.h"], + hdrs = [ + "ops/brgemm.h", + "ops/matmul.h", + ], + defines = select({ + "@platforms//cpu:x86_64": ["GEMMA_ONEDNN=1", "DNNL_EXPERIMENTAL_UKERNEL"], + "//conditions:default": [], + }), deps = [ ":allocator", ":basics", @@ -324,14 +331,20 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", - ], + ] + select({ + "@platforms//cpu:x86_64": ["@onednn//:onednn"], + "//conditions:default": [], + }), ) cc_library( name = "matmul", # allow depending only on this target, without also matmul_env. 
hdrs = ["ops/matmul.h"], - textual_hdrs = ["ops/matmul-inl.h"], + textual_hdrs = [ + "ops/brgemm-inl.h", + "ops/matmul-inl.h", + ], deps = [ ":allocator", ":basics", @@ -345,7 +358,10 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", - ], + ] + select({ + "@platforms//cpu:x86_64": ["@onednn//:onednn"], + "//conditions:default": [], + }), ) cc_library( @@ -362,6 +378,7 @@ cc_library( "ops/matmul_static.h", ], textual_hdrs = [ + "ops/brgemm-inl.h", "ops/matmul_static-inl.h", "ops/matmul-inl.h", ], @@ -378,7 +395,10 @@ cc_library( "@highway//:hwy", "@highway//:profiler", "@highway//:timer", - ], + ] + select({ + "@platforms//cpu:x86_64": ["@onednn//:onednn"], + "//conditions:default": [], + }), ) cc_library( diff --git a/MODULE.bazel b/MODULE.bazel index 0dea7752..a017e82f 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -8,6 +8,7 @@ bazel_dep(name = "bazel_skylib", version = "1.8.1") bazel_dep(name = "googletest", version = "1.17.0") bazel_dep(name = "highway", version = "1.1.0") bazel_dep(name = "nlohmann_json", version = "3.11.3") +bazel_dep(name = "onetbb", version = "2021.13.0") bazel_dep(name = "protobuf", version = "33.4") bazel_dep(name = "platforms", version = "1.0.0") bazel_dep(name = "pybind11_bazel", version = "2.13.6") @@ -25,6 +26,17 @@ git_override( http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +# OneDNN v3.11 for BRGeMM micro-kernel support (optional, x86-64 only). 
+http_archive( + name = "onednn", + build_file = "@//bazel:onednn.BUILD", + sha256 = "04df98b18300daf6c3aa7cc2d5e7ce8a8f430fed1787151daed0254d8dd4e64e", + strip_prefix = "oneDNN-3.11", + urls = [ + "https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v3.11.tar.gz", + ], +) + http_archive( name = "com_google_absl_py", sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758", diff --git a/bazel/onednn.BUILD b/bazel/onednn.BUILD new file mode 100644 index 00000000..b9b16647 --- /dev/null +++ b/bazel/onednn.BUILD @@ -0,0 +1,228 @@ +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +exports_files(["LICENSE"]) + +expand_template( + name = "dnnl_config_h", + out = "include/oneapi/dnnl/dnnl_config.h", + substitutions = { + "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1", + "#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP", + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_TBB", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_TBB", + "#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS", + "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", + "#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE", + "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE", + "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL", + "#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO", + "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", + "#cmakedefine DNNL_SYCL_GENERIC": "#undef DNNL_SYCL_GENERIC", + "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", + "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", + "#cmakedefine 
ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH", + "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE", + "#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING", + "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", + "#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER", + "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", + "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", + "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", + "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", + "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0", + "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0", + "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0", + "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0", + "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0", + "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0", + "#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1", + "#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0", + "#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1", + "#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1", + "#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1", + "#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1", + "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0", + "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0", + "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0", + "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0", + "#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0", + "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0", + "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0", + "#cmakedefine01 
BUILD_REORDER": "#define BUILD_REORDER 0", + "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0", + "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0", + "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0", + "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0", + "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0", + "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1", + "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0", + "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0", + "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0", + "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0", + "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0", + "#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0", + "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0", + "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0", + "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0", + "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0", + "#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1", + "#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0", + }, + template = "include/oneapi/dnnl/dnnl_config.h.in", +) + +expand_template( + name = "dnnl_version_h", + out = "include/oneapi/dnnl/dnnl_version.h", + substitutions = { + "@DNNL_VERSION_MAJOR@": "3", + "@DNNL_VERSION_MINOR@": "11", + "@DNNL_VERSION_PATCH@": "0", + }, + template = "include/oneapi/dnnl/dnnl_version.h.in", +) + +expand_template( + name = "dnnl_version_hash_h", + out = "include/oneapi/dnnl/dnnl_version_hash.h", + substitutions = { + "@DNNL_VERSION_HASH@": "fc6151651a4577beae5ffac5a4132e75d39e1409", + }, + template = "include/oneapi/dnnl/dnnl_version_hash.h.in", +) + +cc_library( + name = "onednn_autogen", + srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]), + copts = [ + "-O1", + "-U_FORTIFY_SOURCE", + "-fexceptions", + "-UUSE_MKL", + "-UUSE_CBLAS", + "-DDNNL_ENABLE_MAX_CPU_ISA", + "-DDNNL_ENABLE_ITT_TASKS", + "-DDNNL_ENABLE_GRAPH_DUMP", 
+ "-DDNNL_EXPERIMENTAL_UKERNEL", + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + "src/graph", + "third_party", + "third_party/ittnotify", + "third_party/xbyak", + ], + textual_hdrs = glob([ + "include/**/*", + "src/common/*.hpp", + "src/cpu/*.hpp", + "src/cpu/**/*.hpp", + "src/cpu/jit_utils/**/*.hpp", + "src/graph/interface/*.hpp", + "src/graph/backend/*.hpp", + "src/graph/backend/dnnl/*.hpp", + "src/graph/backend/dnnl/executables/*.hpp", + "src/graph/backend/fake/*.hpp", + "src/graph/backend/dnnl/passes/*.hpp", + "src/graph/backend/dnnl/patterns/*.hpp", + "src/graph/backend/dnnl/kernels/*.hpp", + "src/graph/utils/*.hpp", + "src/graph/utils/pm/*.hpp", + "third_party/ittnotify/**/*.h", + "third_party/spdlog/**/*.h", + "third_party/xbyak/*.h", + ]) + [ + ":dnnl_config_h", + ":dnnl_version_h", + ":dnnl_version_hash_h", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "onednn", + srcs = glob( + [ + "src/common/*.cpp", + "src/cpu/*.cpp", + "src/cpu/**/*.cpp", + "src/cpu/jit_utils/**/*.cpp", + "src/cpu/x64/**/*.cpp", + "src/graph/interface/*.cpp", + "src/graph/backend/*.cpp", + "src/graph/backend/dnnl/*.cpp", + "src/graph/backend/dnnl/executables/*.cpp", + "src/graph/backend/fake/*.cpp", + "src/graph/backend/dnnl/passes/*.cpp", + "src/graph/backend/dnnl/patterns/*.cpp", + "src/graph/backend/dnnl/kernels/*.cpp", + "src/graph/utils/*.cpp", + "src/graph/utils/pm/*.cpp", + "third_party/ittnotify/*.c", + ], + exclude = [ + "src/cpu/aarch64/**", + "src/cpu/rv64/**", + "src/cpu/ppc64/**", + "src/cpu/s390x/**", + "src/cpu/x64/gemm/**/*_kern_autogen.cpp", + "src/cpu/sycl/**", + ], + ), + copts = [ + "-fexceptions", + "-UUSE_MKL", + "-UUSE_CBLAS", + "-DDNNL_ENABLE_MAX_CPU_ISA", + "-DDNNL_ENABLE_ITT_TASKS", + "-DDNNL_ENABLE_GRAPH_DUMP", + "-DDNNL_EXPERIMENTAL_UKERNEL", + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + "src/graph", + "third_party", + "third_party/ittnotify", + 
"third_party/xbyak", + ], + linkopts = [ + "-lrt", + "-Wl,--allow-multiple-definition", + ], + textual_hdrs = glob([ + "include/**/*", + "src/common/*.hpp", + "src/cpu/*.hpp", + "src/cpu/**/*.hpp", + "src/cpu/jit_utils/**/*.hpp", + "src/graph/interface/*.hpp", + "src/graph/backend/*.hpp", + "src/graph/backend/dnnl/*.hpp", + "src/graph/backend/fake/*.hpp", + "src/graph/backend/dnnl/passes/*.hpp", + "src/graph/backend/dnnl/patterns/*.hpp", + "src/graph/backend/dnnl/kernels/*.hpp", + "src/graph/utils/*.hpp", + "src/graph/utils/pm/*.hpp", + "third_party/ittnotify/**/*.h", + "third_party/spdlog/**/*.h", + "third_party/xbyak/*.h", + ]) + [ + ":dnnl_config_h", + ":dnnl_version_h", + ":dnnl_version_hash_h", + ], + visibility = ["//visibility:public"], + deps = [ + ":onednn_autogen", + "@onetbb//:tbb", + ], +) diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 67c702f5..2e1d2882 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -130,7 +130,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { keep += hwy::ConvertScalarTo(C.Row(0)[hwy::Unpredictable1()]); // Only record times after autotuning finished. - if (per_key->autotune.Best()) times.push_back(elapsed); + bool done = per_key->autotune.Best(); +#if GEMMA_ONEDNN + done = done || per_key->brgemm_autotune.Best(); +#endif + if (done) times.push_back(elapsed); } hwy::PreventElision(keep); env.ctx.pools.MaybeStopSpinning(use_spinning); diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h new file mode 100644 index 00000000..ccf33e69 --- /dev/null +++ b/ops/brgemm-inl.h @@ -0,0 +1,492 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE. + +#if GEMMA_ONEDNN + +static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n, + int64_t k, int64_t batch, int64_t lda, int64_t ldb, + int64_t ldc, dnnl::memory::data_type a_dt, + dnnl::memory::data_type b_dt, + dnnl::memory::data_type c_dt, bool add_C) { + try { + brg = dnnl::ukernel::brgemm(m, n, k, batch, lda, ldb, ldc, a_dt, b_dt, + c_dt, true); + if (!brg) return false; + brg.set_add_C(add_C); + if (!brg.finalize()) return false; + brg.generate(); + return true; + } catch (...) { + return false; + } +} + +template +static HWY_NOINLINE void DoMatMul_BRGeMM( + const MatPtrT& A, const MatPtrT& B, RowPtrs C, size_t M, + size_t K, size_t N, float scale, const float* HWY_RESTRICT add, + const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) { + using dnnl::ukernel::brgemm; + using dnnl::ukernel::pack_type; + using dnnl::ukernel::transform; + + // Level-1 cache: kernels keyed on (M, K, N, config). 
+ const BRGeMMKernelKey kern_key{M, K, N, cfg.M_blk, cfg.N_blk, cfg.K_blk, + cfg.batch_size}; + auto& kern_cache = GetBRGeMMKernelCache(); + auto kern_it = kern_cache.find(kern_key); + + if (kern_it == kern_cache.end()) { + BRGeMMKernelEntry ke; + + ke.K_blk = cfg.K_blk; + ke.N_blk = cfg.N_blk; + ke.M_blk = + static_cast(std::min(static_cast(cfg.M_blk), M)); + + ke.M_tail = M % ke.M_blk; + ke.N_tail = N % ke.N_blk; + ke.K_tail = K % ke.K_blk; + + ke.K_chunks = K / ke.K_blk; + ke.N_full_tiles = N / ke.N_blk; + ke.M_full_tiles = M / ke.M_blk; + ke.N_total_tiles = ke.N_full_tiles + (ke.N_tail ? 1 : 0); + ke.M_total_tiles = ke.M_full_tiles + (ke.M_tail ? 1 : 0); + ke.N_padded = ke.N_total_tiles * ke.N_blk; + + if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0 || + (ke.K_chunks == 0 && ke.K_tail == 0)) { + return; + } + + ke.K_super_size = std::min(cfg.batch_size, ke.K_chunks); + ke.K_super_blocks = (ke.K_chunks > 0) ? ke.K_chunks / ke.K_super_size : 0; + ke.K_super_rem = (ke.K_chunks > 0) ? ke.K_chunks % ke.K_super_size : 0; + ke.batch_full = ke.K_super_size; + ke.batch_rem = ke.K_super_rem; + + const auto a_dt = dnnl::memory::data_type::bf16; + const auto b_dt = dnnl::memory::data_type::bf16; + const auto c_dt = dnnl::memory::data_type::f32; + ke.a_dt_size = dnnl::memory::data_type_size(a_dt); + ke.b_dt_size = dnnl::memory::data_type_size(b_dt); + + const auto pack = brgemm::get_B_pack_type(a_dt, b_dt); + if (pack == pack_type::undef) return; + ke.need_pack = (pack != pack_type::no_trans); + + ke.lda = A.Stride(); + ke.ldb_orig = B.Stride(); + + ke.m_sizes[0] = ke.M_blk; + ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk; + ke.n_sizes[0] = ke.N_blk; + ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk; + const int64_t ldb_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk}; + const int64_t ldc_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk}; + + // Create brgemm kernels for each (M-tile, N-tile) variant. 
+ size_t max_sp = 0; + for (int mi = 0; mi < 2; ++mi) { + for (int ni = 0; ni < 2; ++ni) { + if (mi == 1 && ke.M_tail == 0) continue; + if (ni == 1 && ke.N_tail == 0) continue; + if (mi == 0 && ke.M_full_tiles == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + + const int64_t ms = ke.m_sizes[mi]; + const int64_t ns = ke.n_sizes[ni]; + + if (ke.K_chunks > 0) { + if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, ke.K_blk, + ke.K_super_size, ke.lda, ldb_for[ni], ldc_for[ni], + a_dt, b_dt, c_dt, false)) { + return; + } + max_sp = std::max(max_sp, + ke.brg_first_all[mi][ni].get_scratchpad_size()); + } + if (ke.K_super_blocks > 1) { + if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, ke.K_blk, + ke.batch_full, ke.lda, ldb_for[ni], ldc_for[ni], + a_dt, b_dt, c_dt, true)) { + return; + } + max_sp = + std::max(max_sp, ke.brg_full[mi][ni].get_scratchpad_size()); + } + if (ke.K_super_rem > 0) { + const bool rem_is_first = (ke.K_super_blocks == 0); + auto& target = rem_is_first ? ke.brg_first_rem[mi][ni] + : ke.brg_rem[mi][ni]; + if (!MakeBrgemm(target, ms, ns, ke.K_blk, ke.batch_rem, ke.lda, + ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + !rem_is_first)) { + return; + } + max_sp = std::max(max_sp, target.get_scratchpad_size()); + } + if (ke.K_tail > 0) { + const bool add_c = (ke.K_chunks > 0); + if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, ke.K_tail, 1, ke.lda, + ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + add_c)) { + return; + } + max_sp = + std::max(max_sp, ke.brg_ktail[mi][ni].get_scratchpad_size()); + } + } + } + ke.scratchpad_size = max_sp + 64; + + // Create B-packing transforms. 
+ if (ke.need_pack) { + for (int ni = 0; ni < 2; ++ni) { + if (ni == 1 && ke.N_tail == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + + const int64_t ns = ke.n_sizes[ni]; + if (ke.K_chunks > 0) { + const int64_t K_full = ke.K_chunks * ke.K_blk; + try { + ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, + ke.ldb_orig, ldb_for[ni], b_dt, b_dt); + if (!ke.pack_B[ni]) return; + ke.pack_B[ni].generate(); + ke.blocked_B_size[ni] = ldb_for[ni] * ke.K_blk * ke.b_dt_size; + } catch (...) { + return; + } + } + if (ke.K_tail > 0) { + try { + ke.pack_B_ktail[ni] = transform( + ke.K_tail, ns, pack_type::trans, ke.ldb_orig, ldb_for[ni], + b_dt, b_dt); + if (!ke.pack_B_ktail[ni]) return; + ke.pack_B_ktail[ni].generate(); + ke.blocked_B_ktail_size[ni] = + ldb_for[ni] * ke.K_tail * ke.b_dt_size; + } catch (...) { + return; + } + } + } + } + + // Precompute A/B offset tables for each K-super-block. + for (int ni = 0; ni < 2; ++ni) { + if (ni == 1 && ke.N_tail == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + const int64_t cur_n = ke.n_sizes[ni]; + + if (ke.K_chunks > 0) { + ke.offsets_first_all[ni].resize(ke.K_super_size); + for (int64_t i = 0; i < ke.K_super_size; ++i) { + const int64_t a_off = + i * ke.K_blk * static_cast(ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? i * static_cast(ke.blocked_B_size[ni]) + : i * cur_n * ke.K_blk * static_cast(ke.b_dt_size); + ke.offsets_first_all[ni][i] = {a_off, b_off}; + } + } + + if (ke.K_super_blocks > 1) { + ke.offsets_full[ni].resize(ke.K_super_blocks - 1); + for (int64_t ks = 1; ks < ke.K_super_blocks; ++ks) { + auto& tbl = ke.offsets_full[ni][ks - 1]; + tbl.resize(ke.batch_full); + const int64_t k_start = ks * ke.K_super_size; + for (int64_t i = 0; i < ke.batch_full; ++i) { + const int64_t k_idx = k_start + i; + const int64_t a_off = + k_idx * ke.K_blk * static_cast(ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? 
k_idx * static_cast(ke.blocked_B_size[ni]) + : k_idx * cur_n * ke.K_blk * + static_cast(ke.b_dt_size); + tbl[i] = {a_off, b_off}; + } + } + } + + if (ke.K_super_rem > 0) { + const int64_t k_base = ke.K_super_blocks * ke.K_super_size; + auto& rem_tbl = (ke.K_super_blocks == 0) ? ke.offsets_first_rem[ni] + : ke.offsets_rem[ni]; + rem_tbl.resize(ke.K_super_rem); + for (int64_t i = 0; i < ke.K_super_rem; ++i) { + const int64_t k_idx = k_base + i; + const int64_t a_off = + k_idx * ke.K_blk * static_cast(ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? k_idx * static_cast(ke.blocked_B_size[ni]) + : k_idx * cur_n * ke.K_blk * + static_cast(ke.b_dt_size); + rem_tbl[i] = {a_off, b_off}; + } + } + } + + kern_it = kern_cache.emplace(kern_key, std::move(ke)).first; + } + + BRGeMMKernelEntry& ke = kern_it->second; + if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0) return; + + // Level-2 cache: packed B keyed on (B_ptr, K, N, config). + const uint8_t* A_base = reinterpret_cast(A.Row(0)); + const uint8_t* B_base = reinterpret_cast(B.Row(0)); + + const BRGeMMPackedBKey pb_key{reinterpret_cast(B_base), K, N, + ke.K_blk, ke.N_blk}; + auto& pb_cache = GetBRGeMMPackedBCache(); + auto pb_it = pb_cache.find(pb_key); + + if (pb_it == pb_cache.end()) { + BRGeMMPackedBEntry pe; + pe.B_tile_offset.resize(ke.N_total_tiles, 0); + pe.B_ktail_offset.resize(ke.N_total_tiles, 0); + + if (ke.need_pack) { + size_t total_packed = 0; + for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) { + const int ni = (nt < ke.N_full_tiles) ? 0 : 1; + pe.B_tile_offset[nt] = total_packed; + if (ke.K_chunks > 0) + total_packed += ke.blocked_B_size[ni] * ke.K_chunks; + pe.B_ktail_offset[nt] = total_packed; + if (ke.K_tail > 0) total_packed += ke.blocked_B_ktail_size[ni]; + } + + pe.B_packed_buf.Resize(total_packed); + uint8_t* B_packed = pe.B_packed_buf.data(); + if (!B_packed) return; + + for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) { + const int ni = (nt < ke.N_full_tiles) ? 
0 : 1; + const int64_t b_row = (nt < ke.N_full_tiles) + ? nt * ke.N_blk + : ke.N_full_tiles * ke.N_blk; + const uint8_t* B_in = + B_base + b_row * ke.ldb_orig * ke.b_dt_size; + + try { + if (ke.K_chunks > 0) { + ke.pack_B[ni].execute(const_cast(B_in), + B_packed + pe.B_tile_offset[nt]); + } + if (ke.K_tail > 0) { + const uint8_t* B_in_ktail = + B_in + ke.K_chunks * ke.K_blk * ke.b_dt_size; + ke.pack_B_ktail[ni].execute(const_cast(B_in_ktail), + B_packed + pe.B_ktail_offset[nt]); + } + } catch (...) { + return; + } + } + } + + pb_it = pb_cache.emplace(pb_key, std::move(pe)).first; + } + + const BRGeMMPackedBEntry& pe = pb_it->second; + const uint8_t* B_packed = + ke.need_pack ? pe.B_packed_buf.data() : nullptr; + + std::vector> offsets_ktail(1); + if (ke.K_tail > 0) offsets_ktail[0] = {0, 0}; + + // Execute one (m, n) tile for a given K-super-block. + const auto execute_tile = [&](size_t m_start, size_t n_start, + int64_t k_super, float* temp_C, + uint8_t* scratch) HWY_ATTR { + const int64_t m_tile_idx = m_start / ke.M_blk; + const int64_t n_tile_idx = n_start / ke.N_blk; + const int mi = (m_tile_idx < ke.M_full_tiles) ? 0 : 1; + const int ni = (n_tile_idx < ke.N_full_tiles) ? 0 : 1; + const int64_t cur_m = ke.m_sizes[mi]; + const int64_t cur_n = ke.n_sizes[ni]; + + const size_t real_m = (m_tile_idx < ke.M_full_tiles) + ? m_tile_idx * ke.M_blk + : ke.M_full_tiles * ke.M_blk; + const size_t real_n = (n_tile_idx < ke.N_full_tiles) + ? n_tile_idx * ke.N_blk + : ke.N_full_tiles * ke.N_blk; + + const uint8_t* A_tile = A_base + real_m * ke.lda * ke.a_dt_size; + const void* B_tile = + ke.need_pack + ? static_cast(B_packed + + pe.B_tile_offset[n_tile_idx]) + : static_cast(B_base + + real_n * ke.ldb_orig * ke.b_dt_size); + + float* C_tile_ptr = temp_C; + const int64_t k_total = + ke.K_super_blocks + (ke.K_super_rem > 0 ? 
1 : 0); + + if (k_super < ke.K_super_blocks) { + if (k_super == 0) { + ke.brg_first_all[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_first_all[ni], C_tile_ptr, + scratch); + } else { + ke.brg_full[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_full[ni][k_super - 1], + C_tile_ptr, scratch); + } + } else if (ke.K_super_rem > 0 && k_super == ke.K_super_blocks) { + if (ke.K_super_blocks == 0) { + ke.brg_first_rem[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_first_rem[ni], C_tile_ptr, + scratch); + } else { + ke.brg_rem[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_rem[ni], C_tile_ptr, scratch); + } + } + + const bool is_last = (k_total > 0) ? (k_super == k_total - 1) : true; + if (is_last) { + if (ke.K_tail > 0) { + const uint8_t* A_ktail = + A_tile + ke.K_chunks * ke.K_blk * ke.a_dt_size; + const void* B_ktail = + ke.need_pack + ? static_cast(B_packed + + pe.B_ktail_offset[n_tile_idx]) + : static_cast( + B_base + (real_n * ke.ldb_orig + + ke.K_chunks * ke.K_blk) * + ke.b_dt_size); + ke.brg_ktail[mi][ni].execute(A_ktail, const_cast(B_ktail), + offsets_ktail, C_tile_ptr, scratch); + } + + // Scale and copy temp_C to output. + const hn::ScalableTag df; + const auto vscale = hn::Set(df, scale); + const size_t lanes = hn::Lanes(df); + for (int64_t m = 0; m < cur_m; ++m) { + TC* C_row = C.Row(real_m + m) + real_n; + const float* t_row = C_tile_ptr + m * cur_n; + const float* add_row = add ? 
add + real_n : nullptr; + int64_t n = 0; + if (add_row) { + for (; n + static_cast(lanes) <= cur_n; + n += static_cast(lanes)) { + const auto v = hn::Load(df, t_row + n); + const auto va = hn::Load(df, add_row + n); + const auto result = hn::MulAdd(v, vscale, va); + if constexpr (hwy::IsSame()) { + hn::Store(result, df, reinterpret_cast(C_row) + n); + } else { + const hn::Rebind dc; + hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); + } + } + for (; n < cur_n; ++n) { + float val = t_row[n] * scale + add_row[n]; + C_row[n] = hwy::ConvertScalarTo(val); + } + } else { + for (; n + static_cast(lanes) <= cur_n; + n += static_cast(lanes)) { + const auto v = hn::Load(df, t_row + n); + const auto result = hn::Mul(v, vscale); + if constexpr (hwy::IsSame()) { + hn::Store(result, df, reinterpret_cast(C_row) + n); + } else { + const hn::Rebind dc; + hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); + } + } + for (; n < cur_n; ++n) { + float val = t_row[n] * scale; + C_row[n] = hwy::ConvertScalarTo(val); + } + } + } + } + }; + + // Parallel dispatch: K-super outer, N middle, M inner (keeps B in L2). + const int64_t k_total_supers = + ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0); + const int64_t k_iters = (k_total_supers > 0) ? k_total_supers : 1; + + const size_t num_threads = ctx.pools.MaxWorkersPerCluster(); + const size_t total_n_tiles = ke.N_total_tiles; + const size_t total_m_tiles = ke.M_total_tiles; + const size_t n_tasks = + std::max(size_t{1}, std::min(total_n_tiles, num_threads)); + + const hwy::pool::Caller caller = + ctx.pool_callers.Get(Callers::kBRGeMM); + + ParallelForWithinCluster( + n_tasks, ctx, cluster_idx, caller, + [&](uint64_t task_idx, size_t /*worker*/) HWY_ATTR { + const size_t tiles_per_task = total_n_tiles / n_tasks; + const size_t extra = total_n_tiles % n_tasks; + const size_t n_begin = + task_idx * tiles_per_task + + std::min(static_cast(task_idx), extra); + const size_t n_end = + n_begin + tiles_per_task + (task_idx < extra ? 
1 : 0); + + auto& tbufs = GetBRGeMMThreadBufs(); + tbufs.MaybeSetHwContext(ke.brg_first_all[0][0]); + uint8_t* sp = tbufs.EnsureScratch(ke.scratchpad_size); + + const size_t n_tiles_in_range = n_end - n_begin; + const size_t total_tc = total_m_tiles * n_tiles_in_range; + float* tc_base = tbufs.EnsureTempC(total_tc); + + for (int64_t ks = 0; ks < k_iters; ++ks) { + size_t n_idx = 0; + for (size_t nt = n_begin; nt < n_end; ++nt) { + const size_t n = nt * ke.N_blk; + for (int64_t mt = 0; mt < static_cast(total_m_tiles); + ++mt) { + const size_t m = mt * ke.M_blk; + float* temp_C = + tc_base + (mt * n_tiles_in_range + n_idx) * + BRGeMMThreadBufs::kMaxTempCSize; + execute_tile(m, n, ks, temp_C, sp); + } + ++n_idx; + } + } + }); + + dnnl::ukernel::brgemm::release_hw_context(); + auto& main_bufs = GetBRGeMMThreadBufs(); + main_bufs.hw_ctx_set = false; + main_bufs.hw_ctx_kernel = nullptr; +} + +#endif // GEMMA_ONEDNN diff --git a/ops/brgemm.h b/ops/brgemm.h new file mode 100644 index 00000000..3183c9ce --- /dev/null +++ b/ops/brgemm.h @@ -0,0 +1,297 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512. +// Enabled at runtime via GEMMA_USE_ONEDNN_BRGEMM=1. 
+ +#ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ +#define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "hwy/base.h" + +#if GEMMA_ONEDNN +#include + +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_ukernel.hpp" +#endif // GEMMA_ONEDNN + +namespace gcpp { + +inline bool UseOneDnnBrgemm() { + static const bool enabled = [] { + const char* env = std::getenv("GEMMA_USE_ONEDNN_BRGEMM"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + return enabled; +} + +struct BRGeMMConfig { + int64_t M_blk; + int64_t N_blk; + int64_t K_blk; + int64_t batch_size; + int64_t par_m; +}; + +#if GEMMA_ONEDNN + +// Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16). +// Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}. +inline std::vector BRGeMMCandidates(size_t M, size_t K, + size_t N) { + std::vector out; + static constexpr int64_t kNBlk = 32; + static constexpr int64_t kKBlk = 32; + static constexpr int64_t kMBlkValues[] = {32, 64}; + static constexpr int64_t kBatchValues[] = {16, 32, 64, 128, 256}; + + const int64_t k_chunks = static_cast(K) / kKBlk; + for (int64_t mb : kMBlkValues) { + if (mb > static_cast(M)) continue; + if (kNBlk > static_cast(N)) continue; + for (int64_t bs : kBatchValues) { + const int64_t eff_bs = + (k_chunks > 0) ? std::min(bs, k_chunks) : int64_t{1}; + bool dup = false; + for (const auto& c : out) { + if (c.M_blk == mb && c.batch_size == eff_bs) { + dup = true; + break; + } + } + if (dup) continue; + out.push_back({mb, kNBlk, kKBlk, eff_bs, /*par_m=*/1}); + } + } + if (out.empty()) { + out.push_back({static_cast(std::min(M, size_t{32})), + static_cast(std::min(N, size_t{32})), 32, 1, 1}); + } + return out; +} + +// Hugepage-backed buffer via mmap with MADV_HUGEPAGE for packed-B matrices. 
+class HugePageBuffer { + public: + HugePageBuffer() = default; + ~HugePageBuffer() { + if (ptr_ && size_) munmap(ptr_, size_); + } + + HugePageBuffer(HugePageBuffer&& o) noexcept + : ptr_(o.ptr_), size_(o.size_) { + o.ptr_ = nullptr; + o.size_ = 0; + } + HugePageBuffer& operator=(HugePageBuffer&& o) noexcept { + if (this != &o) { + if (ptr_ && size_) munmap(ptr_, size_); + ptr_ = o.ptr_; + size_ = o.size_; + o.ptr_ = nullptr; + o.size_ = 0; + } + return *this; + } + + HugePageBuffer(const HugePageBuffer&) = delete; + HugePageBuffer& operator=(const HugePageBuffer&) = delete; + + void Resize(size_t n) { + if (ptr_ && size_) munmap(ptr_, size_); + static constexpr size_t kHugePageSize = 2u << 20; + size_ = (n + kHugePageSize - 1) & ~(kHugePageSize - 1); + ptr_ = static_cast(mmap(nullptr, size_, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + if (ptr_ == MAP_FAILED) { + ptr_ = nullptr; + size_ = 0; + return; + } + madvise(ptr_, size_, MADV_HUGEPAGE); + for (size_t off = 0; off < size_; off += kHugePageSize) { + static_cast(ptr_)[off] = 0; + } + } + + uint8_t* data() { return ptr_; } + const uint8_t* data() const { return ptr_; } + size_t size() const { return size_; } + + private: + uint8_t* ptr_ = nullptr; + size_t size_ = 0; +}; + +// Kernel cache key: identifies a JIT-compiled kernel set. 
+struct BRGeMMKernelKey { + size_t M, K, N; + int64_t M_blk, N_blk, K_blk, batch_size; + bool operator==(const BRGeMMKernelKey& o) const { + return M == o.M && K == o.K && N == o.N && M_blk == o.M_blk && + N_blk == o.N_blk && K_blk == o.K_blk && batch_size == o.batch_size; + } +}; + +struct BRGeMMKernelKeyHash { + size_t operator()(const BRGeMMKernelKey& k) const { + size_t h = 14695981039346656037ULL; + h = (h ^ k.M) * 1099511628211ULL; + h = (h ^ k.K) * 1099511628211ULL; + h = (h ^ k.N) * 1099511628211ULL; + h = (h ^ static_cast(k.M_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.N_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.K_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.batch_size)) * 1099511628211ULL; + return h; + } +}; + +// Cached JIT-compiled kernels with precomputed tile parameters and offsets. +struct BRGeMMKernelEntry { + int64_t M_blk, N_blk, K_blk; + int64_t M_tail, N_tail, K_tail; + int64_t K_chunks; + int64_t M_full_tiles, N_full_tiles; + int64_t M_total_tiles, N_total_tiles; + int64_t K_super_size, K_super_blocks; + int64_t K_super_rem; + int64_t batch_full, batch_rem; + int64_t m_sizes[2], n_sizes[2]; + int64_t lda; + int64_t ldb_orig; + bool need_pack; + size_t a_dt_size, b_dt_size; + size_t N_padded; + + // Kernels indexed by [m_tail_flag][n_tail_flag]. + dnnl::ukernel::brgemm brg_first_all[2][2]; + dnnl::ukernel::brgemm brg_full[2][2]; + dnnl::ukernel::brgemm brg_ktail[2][2]; + dnnl::ukernel::brgemm brg_first_rem[2][2]; + dnnl::ukernel::brgemm brg_rem[2][2]; + + // B-packing transforms indexed by n_tail_flag. + dnnl::ukernel::transform pack_B[2], pack_B_ktail[2]; + size_t blocked_B_size[2] = {0, 0}; + size_t blocked_B_ktail_size[2] = {0, 0}; + + size_t scratchpad_size = 0; + + using OffsetVec = + std::vector>; + OffsetVec offsets_first_all[2]; + std::vector offsets_full[2]; + OffsetVec offsets_first_rem[2]; + OffsetVec offsets_rem[2]; +}; + +// Packed-B cache key. 
+struct BRGeMMPackedBKey { + uintptr_t B_ptr; + size_t K, N; + int64_t K_blk, N_blk; + bool operator==(const BRGeMMPackedBKey& o) const { + return B_ptr == o.B_ptr && K == o.K && N == o.N && K_blk == o.K_blk && + N_blk == o.N_blk; + } +}; + +struct BRGeMMPackedBKeyHash { + size_t operator()(const BRGeMMPackedBKey& k) const { + size_t h = 14695981039346656037ULL; + h = (h ^ k.B_ptr) * 1099511628211ULL; + h = (h ^ k.K) * 1099511628211ULL; + h = (h ^ k.N) * 1099511628211ULL; + h = (h ^ static_cast(k.K_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.N_blk)) * 1099511628211ULL; + return h; + } +}; + +struct BRGeMMPackedBEntry { + HugePageBuffer B_packed_buf; + std::vector B_tile_offset; + std::vector B_ktail_offset; +}; + +// Thread-local buffers for BRGeMM parallel dispatch. +struct BRGeMMThreadBufs { + static constexpr size_t kMaxTempCSize = 64 * 64; + + std::vector scratch; + std::vector tc_storage; + bool hw_ctx_set = false; + const void* hw_ctx_kernel = nullptr; + + uint8_t* EnsureScratch(size_t size) { + if (scratch.size() < size + 64) scratch.resize(size + 64); + return scratch.data() + + (64 - (reinterpret_cast(scratch.data()) % 64)); + } + + float* EnsureTempC(size_t n_tiles) { + const size_t need = n_tiles * kMaxTempCSize * sizeof(float) + 64; + if (tc_storage.size() < need) tc_storage.resize(need); + return reinterpret_cast( + (reinterpret_cast(tc_storage.data()) + 63) & + ~uintptr_t{63}); + } + + void MaybeSetHwContext(const dnnl::ukernel::brgemm& brg) { + const void* brg_ptr = &brg; + if (!hw_ctx_set || hw_ctx_kernel != brg_ptr) { + brg.set_hw_context(); + hw_ctx_set = true; + hw_ctx_kernel = brg_ptr; + } + } +}; + +inline BRGeMMThreadBufs& GetBRGeMMThreadBufs() { + static thread_local BRGeMMThreadBufs bufs; + return bufs; +} + +// Singleton caches. Thread-safety: MatMul is not called concurrently per env. 
+inline auto& GetBRGeMMKernelCache() { + static std::unordered_map + cache; + return cache; +} + +inline auto& GetBRGeMMPackedBCache() { + static std::unordered_map + cache; + return cache; +} + +#endif // GEMMA_ONEDNN + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 4b217a15..a15a2fe3 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -47,6 +47,10 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +#if GEMMA_ONEDNN +#include "ops/brgemm-inl.h" // DoMatMul_BRGeMM +#endif // GEMMA_ONEDNN + // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. template > static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { @@ -1077,6 +1081,46 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = MMImpl::FindOrAddPerKey( M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); +#if GEMMA_ONEDNN + // BRGeMM path for BF16×BF16 on Intel AMX/AVX-512. + // Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint). + if constexpr (IsBF16() && IsBF16()) { + if (UseOneDnnBrgemm() && M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) { + const float scale = A.Scale() * B.Scale(); + MMAutoTune& brg_tuner = per_key.brgemm_autotune; + + if (HWY_LIKELY(brg_tuner.Best())) { + DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, *brg_tuner.Best(), + env.ctx, cluster_idx); + return &per_key; + } + + if (HWY_UNLIKELY(!brg_tuner.HasCandidates())) { + brg_tuner.SetCandidates(BRGeMMCandidates(M, K, N)); + } + + const BRGeMMConfig& cfg = brg_tuner.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx, + cluster_idx); + const uint64_t t1 = + env.have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); + brg_tuner.NotifyTicks(t1 - t0); + + if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) { + const BRGeMMConfig& best = *brg_tuner.Best(); + fprintf(stderr, + "BRGeMM best: %zux%zux%zu M_blk=%ld N_blk=%ld K_blk=%ld " + "batch=%ld\n", + M, K, N, static_cast(best.M_blk), + static_cast(best.N_blk), static_cast(best.K_blk), + static_cast(best.batch_size)); + } + return &per_key; + } + } // if constexpr BF16/float +#endif // GEMMA_ONEDNN + // (Also auto-tunes, hence outside the timed section to prevent interference.) const StridedViewBF A_view = MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options); diff --git a/ops/matmul.h b/ops/matmul.h index 0f3d2866..c03ec53f 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -32,6 +32,7 @@ #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/profiler.h" +#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN // IWYU pragma: end_exports namespace gcpp { @@ -639,6 +640,9 @@ class MMKeys { struct MMPerKey { MMAutoTune autotune; MMAutoTune autotune_par_a; +#if GEMMA_ONEDNN + MMAutoTune brgemm_autotune; +#endif // GEMMA_ONEDNN }; // Stores state shared across MatMul calls. Non-copyable. 
`ctx` must outlive diff --git a/util/zones.cc b/util/zones.cc index aec4bbd0..b552bb17 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -135,6 +135,8 @@ const char* CallerName(Callers caller) { return "Att.DotSoftmaxWeightedSum"; case Callers::kBlobWriter: return "BlobWriter"; + case Callers::kBRGeMM: + return "BRGeMM"; case Callers::kCompress: return "Compress"; case Callers::kFixupWeights: diff --git a/util/zones.h b/util/zones.h index 64b859d2..ba3d5a9b 100644 --- a/util/zones.h +++ b/util/zones.h @@ -81,6 +81,7 @@ enum class Callers { // Keep sorted kAttComputeQKV, kAttDotSoftmaxWeightedSum, kBlobWriter, + kBRGeMM, kCompress, kFixupWeights, kFlashAttention, From 1308355ff021edf9bdf5aebce314543b74c65627 Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Thu, 23 Apr 2026 19:23:55 +0000 Subject: [PATCH 02/11] fixing the copyright info --- ops/brgemm-inl.h | 2 +- ops/brgemm.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index ccf33e69..266191cc 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2026 DeepMind Technologies Limited. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ops/brgemm.h b/ops/brgemm.h index 3183c9ce..1258a3d6 100644 --- a/ops/brgemm.h +++ b/ops/brgemm.h @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2026 DeepMind Technologies Limited. 
// SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); From 656444ff25c6cc570e5e72a6dcd2a282d75e77f0 Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Thu, 23 Apr 2026 20:32:48 +0000 Subject: [PATCH 03/11] Removing OneTBB dependency --- MODULE.bazel | 1 - bazel/onednn.BUILD | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index a017e82f..e44c370b 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -8,7 +8,6 @@ bazel_dep(name = "bazel_skylib", version = "1.8.1") bazel_dep(name = "googletest", version = "1.17.0") bazel_dep(name = "highway", version = "1.1.0") bazel_dep(name = "nlohmann_json", version = "3.11.3") -bazel_dep(name = "onetbb", version = "2021.13.0") bazel_dep(name = "protobuf", version = "33.4") bazel_dep(name = "platforms", version = "1.0.0") bazel_dep(name = "pybind11_bazel", version = "2.13.6") diff --git a/bazel/onednn.BUILD b/bazel/onednn.BUILD index b9b16647..0cbd436d 100644 --- a/bazel/onednn.BUILD +++ b/bazel/onednn.BUILD @@ -8,8 +8,8 @@ expand_template( substitutions = { "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1", "#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP", - "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_TBB", - "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_TBB", + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_SEQ", "#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS", "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", "#cmakedefine DNNL_GPU_VENDOR 
DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE", @@ -223,6 +223,5 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":onednn_autogen", - "@onetbb//:tbb", ], ) From f8527a15f115050487efcff86796568c8cfc791e Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Thu, 23 Apr 2026 21:39:30 +0000 Subject: [PATCH 04/11] Fixed the compile time flag to designate BRGEMM path --- BUILD.bazel | 15 +++++++++++---- ops/bench_matmul.cc | 2 +- ops/brgemm-inl.h | 4 ++-- ops/brgemm.h | 19 +++++-------------- ops/matmul-inl.h | 10 +++++----- ops/matmul.h | 6 +++--- 6 files changed, 27 insertions(+), 29 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 704bd391..d61492b0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -27,6 +27,13 @@ exports_files([ ".github/workflows/build.yml", ]) +# To enable OneDNN BRGeMM support, build with: +# bazel build --define gemma_onednn_brgemm=1 ... +config_setting( + name = "gemma_onednn_brgemm", + define_values = {"gemma_onednn_brgemm": "1"}, +) + cc_library( name = "basics", srcs = ["util/basics.cc"], @@ -318,7 +325,7 @@ cc_library( "ops/matmul.h", ], defines = select({ - "@platforms//cpu:x86_64": ["GEMMA_ONEDNN=1", "DNNL_EXPERIMENTAL_UKERNEL"], + ":gemma_onednn_brgemm": ["GEMMA_ONEDNN_BRGEMM=1", "DNNL_EXPERIMENTAL_UKERNEL"], "//conditions:default": [], }), deps = [ @@ -332,7 +339,7 @@ cc_library( "@highway//:nanobenchmark", "@highway//:profiler", ] + select({ - "@platforms//cpu:x86_64": ["@onednn//:onednn"], + ":gemma_onednn_brgemm": ["@onednn//:onednn"], "//conditions:default": [], }), ) @@ -359,7 +366,7 @@ cc_library( "@highway//:nanobenchmark", "@highway//:profiler", ] + select({ - "@platforms//cpu:x86_64": ["@onednn//:onednn"], + ":gemma_onednn_brgemm": ["@onednn//:onednn"], "//conditions:default": [], }), ) @@ -396,7 +403,7 @@ cc_library( "@highway//:profiler", "@highway//:timer", ] + select({ - "@platforms//cpu:x86_64": ["@onednn//:onednn"], + ":gemma_onednn_brgemm": ["@onednn//:onednn"], 
"//conditions:default": [], }), ) diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 2e1d2882..11375364 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -131,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Only record times after autotuning finished. bool done = per_key->autotune.Best(); -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM done = done || per_key->brgemm_autotune.Best(); #endif if (done) times.push_back(elapsed); diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index 266191cc..e1150977 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -15,7 +15,7 @@ // BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE. -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n, int64_t k, int64_t batch, int64_t lda, int64_t ldb, @@ -489,4 +489,4 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( main_bufs.hw_ctx_kernel = nullptr; } -#endif // GEMMA_ONEDNN +#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM diff --git a/ops/brgemm.h b/ops/brgemm.h index 1258a3d6..74a4be0d 100644 --- a/ops/brgemm.h +++ b/ops/brgemm.h @@ -14,7 +14,7 @@ // limitations under the License. // OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512. -// Enabled at runtime via GEMMA_USE_ONEDNN_BRGEMM=1. +// Enabled at compile time via GEMMA_ONEDNN_BRGEMM_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1). 
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ #define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ @@ -23,30 +23,21 @@ #include #include -#include #include #include #include #include "hwy/base.h" -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM #include #include "oneapi/dnnl/dnnl.hpp" #include "oneapi/dnnl/dnnl_ukernel.hpp" -#endif // GEMMA_ONEDNN +#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM namespace gcpp { -inline bool UseOneDnnBrgemm() { - static const bool enabled = [] { - const char* env = std::getenv("GEMMA_USE_ONEDNN_BRGEMM"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; - }(); - return enabled; -} - struct BRGeMMConfig { int64_t M_blk; int64_t N_blk; @@ -55,7 +46,7 @@ struct BRGeMMConfig { int64_t par_m; }; -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM // Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16). // Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}. @@ -290,7 +281,7 @@ inline auto& GetBRGeMMPackedBCache() { return cache; } -#endif // GEMMA_ONEDNN +#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM } // namespace gcpp diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index a15a2fe3..e20331b6 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -47,9 +47,9 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM #include "ops/brgemm-inl.h" // DoMatMul_BRGeMM -#endif // GEMMA_ONEDNN +#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. template > @@ -1081,11 +1081,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = MMImpl::FindOrAddPerKey( M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM // BRGeMM path for BF16×BF16 on Intel AMX/AVX-512. // Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint). 
if constexpr (IsBF16() && IsBF16()) { - if (UseOneDnnBrgemm() && M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) { + if (M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) { const float scale = A.Scale() * B.Scale(); MMAutoTune& brg_tuner = per_key.brgemm_autotune; @@ -1119,7 +1119,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, return &per_key; } } // if constexpr BF16/float -#endif // GEMMA_ONEDNN +#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM // (Also auto-tunes, hence outside the timed section to prevent interference.) const StridedViewBF A_view = diff --git a/ops/matmul.h b/ops/matmul.h index c03ec53f..b715d2ad 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -32,7 +32,7 @@ #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/profiler.h" -#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN +#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM_BRGEMM // IWYU pragma: end_exports namespace gcpp { @@ -640,9 +640,9 @@ class MMKeys { struct MMPerKey { MMAutoTune autotune; MMAutoTune autotune_par_a; -#if GEMMA_ONEDNN +#if GEMMA_ONEDNN_BRGEMM_BRGEMM MMAutoTune brgemm_autotune; -#endif // GEMMA_ONEDNN +#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM }; // Stores state shared across MatMul calls. Non-copyable. 
`ctx` must outlive

From 0dde315370d438c71c412a555d4c8510b368c946 Mon Sep 17 00:00:00 2001
From: Bibek Bhattarai
Date: Thu, 23 Apr 2026 21:57:52 +0000
Subject: [PATCH 05/11] Adding the CMake-based build support for oneDNN BRGeMM

---
 CMakeLists.txt      | 31 +++++++++++++++++++++++++++++++
 ops/bench_matmul.cc |  2 +-
 ops/brgemm-inl.h    |  4 ++--
 ops/brgemm.h        | 10 +++++-----
 ops/matmul-inl.h    |  8 ++++----
 ops/matmul.h        |  6 +++---
 6 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 52dc7ca7..2ca0c433 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,6 +22,10 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

+# Optional: OneDNN BRGeMM micro-kernel support (x86-64 only).
+# Enable with: cmake -DGEMMA_ONEDNN_BRGEMM=ON ...
+option(GEMMA_ONEDNN_BRGEMM "Enable OneDNN BRGeMM micro-kernel for MatMul (x86-64)" OFF)
+
 if(EMSCRIPTEN)
   add_compile_options("-sMEMORY64")
   add_compile_options("-msimd128")
@@ -85,6 +89,23 @@ if(EMSCRIPTEN)
   target_compile_options(benchmark PRIVATE -Wno-c2y-extensions)
 endif()

+# OneDNN BRGeMM micro-kernel support (optional, x86-64 only).
+if(GEMMA_ONEDNN_BRGEMM) + set(DNNL_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(DNNL_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) + set(DNNL_CPU_RUNTIME "SEQ" CACHE STRING "" FORCE) + set(DNNL_GPU_RUNTIME "NONE" CACHE STRING "" FORCE) + set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "" FORCE) + set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + FetchContent_Declare(onednn + GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git + GIT_TAG v3.11 + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(onednn) + message(STATUS "OneDNN BRGeMM micro-kernel support enabled") +endif() + # Base source files set(SOURCES compression/compress-inl.h @@ -141,6 +162,8 @@ set(SOURCES ops/matmul-inl.h ops/matmul.cc ops/matmul.h + ops/brgemm.h + ops/brgemm-inl.h ops/ops-inl.h ops/ops.h ops/sum-inl.h @@ -191,6 +214,10 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static) target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) +if(GEMMA_ONEDNN_BRGEMM) + target_compile_definitions(libgemma PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL) + target_link_libraries(libgemma dnnl) +endif() install(TARGETS libgemma DESTINATION lib) # Shared library target for C# interop @@ -215,6 +242,10 @@ target_compile_definitions(gemma_shared $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX> ) target_compile_options(gemma_shared PRIVATE $<$:-Wno-deprecated-declarations>) +if(GEMMA_ONEDNN_BRGEMM) + target_compile_definitions(gemma_shared PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL) + target_link_libraries(gemma_shared PRIVATE dnnl) +endif() install(TARGETS gemma_shared DESTINATION lib) install(FILES gemma/c_api.h DESTINATION include/gemma) install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma) diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 11375364..e9432276 100644 --- a/ops/bench_matmul.cc 
+++ b/ops/bench_matmul.cc @@ -131,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Only record times after autotuning finished. bool done = per_key->autotune.Best(); -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM done = done || per_key->brgemm_autotune.Best(); #endif if (done) times.push_back(elapsed); diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index e1150977..78137a3f 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -15,7 +15,7 @@ // BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE. -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n, int64_t k, int64_t batch, int64_t lda, int64_t ldb, @@ -489,4 +489,4 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( main_bufs.hw_ctx_kernel = nullptr; } -#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM +#endif // GEMMA_ONEDNN_BRGEMM diff --git a/ops/brgemm.h b/ops/brgemm.h index 74a4be0d..38e05509 100644 --- a/ops/brgemm.h +++ b/ops/brgemm.h @@ -14,7 +14,7 @@ // limitations under the License. // OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512. -// Enabled at compile time via GEMMA_ONEDNN_BRGEMM_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1). +// Enabled at compile time via GEMMA_ONEDNN_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1). #ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ #define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ @@ -29,12 +29,12 @@ #include "hwy/base.h" -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM #include #include "oneapi/dnnl/dnnl.hpp" #include "oneapi/dnnl/dnnl_ukernel.hpp" -#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM +#endif // GEMMA_ONEDNN_BRGEMM namespace gcpp { @@ -46,7 +46,7 @@ struct BRGeMMConfig { int64_t par_m; }; -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM // Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16). // Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}. 
@@ -281,7 +281,7 @@ inline auto& GetBRGeMMPackedBCache() { return cache; } -#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM +#endif // GEMMA_ONEDNN_BRGEMM } // namespace gcpp diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index e20331b6..fda9d821 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -47,9 +47,9 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM #include "ops/brgemm-inl.h" // DoMatMul_BRGeMM -#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM +#endif // GEMMA_ONEDNN_BRGEMM // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. template > @@ -1081,7 +1081,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = MMImpl::FindOrAddPerKey( M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM // BRGeMM path for BF16×BF16 on Intel AMX/AVX-512. // Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint). if constexpr (IsBF16() && IsBF16()) { @@ -1119,7 +1119,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, return &per_key; } } // if constexpr BF16/float -#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM +#endif // GEMMA_ONEDNN_BRGEMM // (Also auto-tunes, hence outside the timed section to prevent interference.) 
const StridedViewBF A_view = diff --git a/ops/matmul.h b/ops/matmul.h index b715d2ad..4724bad9 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -32,7 +32,7 @@ #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/profiler.h" -#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM_BRGEMM +#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM // IWYU pragma: end_exports namespace gcpp { @@ -640,9 +640,9 @@ class MMKeys { struct MMPerKey { MMAutoTune autotune; MMAutoTune autotune_par_a; -#if GEMMA_ONEDNN_BRGEMM_BRGEMM +#if GEMMA_ONEDNN_BRGEMM MMAutoTune brgemm_autotune; -#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM +#endif // GEMMA_ONEDNN_BRGEMM }; // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive From fd3b1199f10d4c95cc4beb176b8bdc2a60e0a7e7 Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Tue, 5 May 2026 19:00:15 +0000 Subject: [PATCH 06/11] fixed dtypes and syntax divergence from codebase --- ops/brgemm-inl.h | 187 +++++++++++++++++++++++++++++------------------ ops/brgemm.h | 80 ++++++++++---------- ops/matmul-inl.h | 16 ++-- 3 files changed, 162 insertions(+), 121 deletions(-) diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index 78137a3f..08d8856c 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -13,7 +13,36 @@ // See the License for the specific language governing permissions and // limitations under the License. -// BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE. +// BRGeMM dispatch for BF16 MatMul on Intel AMX/AVX-512. + +#include +#include + +#include +#include +#include + +#include "ops/brgemm.h" +#include "ops/matmul.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "util/zones.h" +#include "hwy/base.h" + +// Include guard for (potentially) SIMD code. 
+#if defined(THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; #if GEMMA_ONEDNN_BRGEMM @@ -55,8 +84,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ke.K_blk = cfg.K_blk; ke.N_blk = cfg.N_blk; - ke.M_blk = - static_cast(std::min(static_cast(cfg.M_blk), M)); + ke.M_blk = std::min(cfg.M_blk, M); ke.M_tail = M % ke.M_blk; ke.N_tail = N % ke.N_blk; @@ -97,10 +125,13 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk; ke.n_sizes[0] = ke.N_blk; ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk; - const int64_t ldb_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk}; - const int64_t ldc_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk}; + const int64_t ldb_for[2] = {static_cast(ke.N_blk), + static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; + const int64_t ldc_for[2] = {static_cast(ke.N_blk), + static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; - // Create brgemm kernels for each (M-tile, N-tile) variant. + // Create brgemm kernels for full/tail M and N tile sizes. + // mi=0 is the full M tile, mi=1 is the M-tail; likewise for ni and N. 
size_t max_sp = 0; for (int mi = 0; mi < 2; ++mi) { for (int ni = 0; ni < 2; ++ni) { @@ -109,22 +140,25 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (mi == 0 && ke.M_full_tiles == 0) continue; if (ni == 0 && ke.N_full_tiles == 0) continue; - const int64_t ms = ke.m_sizes[mi]; - const int64_t ns = ke.n_sizes[ni]; + const int64_t ms = static_cast(ke.m_sizes[mi]); + const int64_t ns = static_cast(ke.n_sizes[ni]); if (ke.K_chunks > 0) { - if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, ke.K_blk, - ke.K_super_size, ke.lda, ldb_for[ni], ldc_for[ni], - a_dt, b_dt, c_dt, false)) { + if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, + static_cast(ke.K_blk), + static_cast(ke.K_super_size), ke.lda, + ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + false)) { return; } max_sp = std::max(max_sp, ke.brg_first_all[mi][ni].get_scratchpad_size()); } if (ke.K_super_blocks > 1) { - if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, ke.K_blk, - ke.batch_full, ke.lda, ldb_for[ni], ldc_for[ni], - a_dt, b_dt, c_dt, true)) { + if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, + static_cast(ke.K_blk), + static_cast(ke.batch_full), ke.lda, + ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, true)) { return; } max_sp = @@ -134,7 +168,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( const bool rem_is_first = (ke.K_super_blocks == 0); auto& target = rem_is_first ? 
ke.brg_first_rem[mi][ni] : ke.brg_rem[mi][ni]; - if (!MakeBrgemm(target, ms, ns, ke.K_blk, ke.batch_rem, ke.lda, + if (!MakeBrgemm(target, ms, ns, static_cast(ke.K_blk), + static_cast(ke.batch_rem), ke.lda, ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, !rem_is_first)) { return; @@ -143,7 +178,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( } if (ke.K_tail > 0) { const bool add_c = (ke.K_chunks > 0); - if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, ke.K_tail, 1, ke.lda, + if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, + static_cast(ke.K_tail), 1, ke.lda, ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, add_c)) { return; @@ -161,15 +197,17 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (ni == 1 && ke.N_tail == 0) continue; if (ni == 0 && ke.N_full_tiles == 0) continue; - const int64_t ns = ke.n_sizes[ni]; + const int64_t ns = static_cast(ke.n_sizes[ni]); if (ke.K_chunks > 0) { - const int64_t K_full = ke.K_chunks * ke.K_blk; + const int64_t K_full = + static_cast(ke.K_chunks * ke.K_blk); try { ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, ke.ldb_orig, ldb_for[ni], b_dt, b_dt); if (!ke.pack_B[ni]) return; ke.pack_B[ni].generate(); - ke.blocked_B_size[ni] = ldb_for[ni] * ke.K_blk * ke.b_dt_size; + ke.blocked_B_size[ni] = static_cast(ldb_for[ni]) * + ke.K_blk * ke.b_dt_size; } catch (...) { return; } @@ -177,12 +215,12 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (ke.K_tail > 0) { try { ke.pack_B_ktail[ni] = transform( - ke.K_tail, ns, pack_type::trans, ke.ldb_orig, ldb_for[ni], - b_dt, b_dt); + static_cast(ke.K_tail), ns, pack_type::trans, + ke.ldb_orig, ldb_for[ni], b_dt, b_dt); if (!ke.pack_B_ktail[ni]) return; ke.pack_B_ktail[ni].generate(); ke.blocked_B_ktail_size[ni] = - ldb_for[ni] * ke.K_tail * ke.b_dt_size; + static_cast(ldb_for[ni]) * ke.K_tail * ke.b_dt_size; } catch (...) 
{ return; } @@ -194,55 +232,55 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( for (int ni = 0; ni < 2; ++ni) { if (ni == 1 && ke.N_tail == 0) continue; if (ni == 0 && ke.N_full_tiles == 0) continue; - const int64_t cur_n = ke.n_sizes[ni]; + const size_t cur_n = ke.n_sizes[ni]; if (ke.K_chunks > 0) { ke.offsets_first_all[ni].resize(ke.K_super_size); - for (int64_t i = 0; i < ke.K_super_size; ++i) { + for (size_t i = 0; i < ke.K_super_size; ++i) { const int64_t a_off = - i * ke.K_blk * static_cast(ke.a_dt_size); + static_cast(i * ke.K_blk * ke.a_dt_size); const int64_t b_off = ke.need_pack - ? i * static_cast(ke.blocked_B_size[ni]) - : i * cur_n * ke.K_blk * static_cast(ke.b_dt_size); + ? static_cast(i * ke.blocked_B_size[ni]) + : static_cast(i * cur_n * ke.K_blk * ke.b_dt_size); ke.offsets_first_all[ni][i] = {a_off, b_off}; } } if (ke.K_super_blocks > 1) { ke.offsets_full[ni].resize(ke.K_super_blocks - 1); - for (int64_t ks = 1; ks < ke.K_super_blocks; ++ks) { + for (size_t ks = 1; ks < ke.K_super_blocks; ++ks) { auto& tbl = ke.offsets_full[ni][ks - 1]; tbl.resize(ke.batch_full); - const int64_t k_start = ks * ke.K_super_size; - for (int64_t i = 0; i < ke.batch_full; ++i) { - const int64_t k_idx = k_start + i; + const size_t k_start = ks * ke.K_super_size; + for (size_t i = 0; i < ke.batch_full; ++i) { + const size_t k_idx = k_start + i; const int64_t a_off = - k_idx * ke.K_blk * static_cast(ke.a_dt_size); + static_cast(k_idx * ke.K_blk * ke.a_dt_size); const int64_t b_off = ke.need_pack - ? k_idx * static_cast(ke.blocked_B_size[ni]) - : k_idx * cur_n * ke.K_blk * - static_cast(ke.b_dt_size); + ? static_cast(k_idx * ke.blocked_B_size[ni]) + : static_cast(k_idx * cur_n * ke.K_blk * + ke.b_dt_size); tbl[i] = {a_off, b_off}; } } } if (ke.K_super_rem > 0) { - const int64_t k_base = ke.K_super_blocks * ke.K_super_size; + const size_t k_base = ke.K_super_blocks * ke.K_super_size; auto& rem_tbl = (ke.K_super_blocks == 0) ? 
ke.offsets_first_rem[ni] : ke.offsets_rem[ni]; rem_tbl.resize(ke.K_super_rem); - for (int64_t i = 0; i < ke.K_super_rem; ++i) { - const int64_t k_idx = k_base + i; + for (size_t i = 0; i < ke.K_super_rem; ++i) { + const size_t k_idx = k_base + i; const int64_t a_off = - k_idx * ke.K_blk * static_cast(ke.a_dt_size); + static_cast(k_idx * ke.K_blk * ke.a_dt_size); const int64_t b_off = ke.need_pack - ? k_idx * static_cast(ke.blocked_B_size[ni]) - : k_idx * cur_n * ke.K_blk * - static_cast(ke.b_dt_size); + ? static_cast(k_idx * ke.blocked_B_size[ni]) + : static_cast(k_idx * cur_n * ke.K_blk * + ke.b_dt_size); rem_tbl[i] = {a_off, b_off}; } } @@ -270,7 +308,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (ke.need_pack) { size_t total_packed = 0; - for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) { + for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) { const int ni = (nt < ke.N_full_tiles) ? 0 : 1; pe.B_tile_offset[nt] = total_packed; if (ke.K_chunks > 0) @@ -283,13 +321,13 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( uint8_t* B_packed = pe.B_packed_buf.data(); if (!B_packed) return; - for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) { + for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) { const int ni = (nt < ke.N_full_tiles) ? 0 : 1; - const int64_t b_row = (nt < ke.N_full_tiles) - ? nt * ke.N_blk - : ke.N_full_tiles * ke.N_blk; + const size_t b_row = (nt < ke.N_full_tiles) + ? nt * ke.N_blk + : ke.N_full_tiles * ke.N_blk; const uint8_t* B_in = - B_base + b_row * ke.ldb_orig * ke.b_dt_size; + B_base + b_row * static_cast(ke.ldb_orig) * ke.b_dt_size; try { if (ke.K_chunks > 0) { @@ -320,14 +358,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( // Execute one (m, n) tile for a given K-super-block. 
const auto execute_tile = [&](size_t m_start, size_t n_start, - int64_t k_super, float* temp_C, + size_t k_super, float* temp_C, uint8_t* scratch) HWY_ATTR { - const int64_t m_tile_idx = m_start / ke.M_blk; - const int64_t n_tile_idx = n_start / ke.N_blk; + const size_t m_tile_idx = m_start / ke.M_blk; + const size_t n_tile_idx = n_start / ke.N_blk; const int mi = (m_tile_idx < ke.M_full_tiles) ? 0 : 1; const int ni = (n_tile_idx < ke.N_full_tiles) ? 0 : 1; - const int64_t cur_m = ke.m_sizes[mi]; - const int64_t cur_n = ke.n_sizes[ni]; + const size_t cur_m = ke.m_sizes[mi]; + const size_t cur_n = ke.n_sizes[ni]; const size_t real_m = (m_tile_idx < ke.M_full_tiles) ? m_tile_idx * ke.M_blk @@ -336,16 +374,18 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ? n_tile_idx * ke.N_blk : ke.N_full_tiles * ke.N_blk; - const uint8_t* A_tile = A_base + real_m * ke.lda * ke.a_dt_size; + const uint8_t* A_tile = + A_base + real_m * static_cast(ke.lda) * ke.a_dt_size; const void* B_tile = ke.need_pack ? static_cast(B_packed + pe.B_tile_offset[n_tile_idx]) - : static_cast(B_base + - real_n * ke.ldb_orig * ke.b_dt_size); + : static_cast( + B_base + + real_n * static_cast(ke.ldb_orig) * ke.b_dt_size); float* C_tile_ptr = temp_C; - const int64_t k_total = + const size_t k_total = ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0); if (k_super < ke.K_super_blocks) { @@ -379,7 +419,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ? 
static_cast(B_packed + pe.B_ktail_offset[n_tile_idx]) : static_cast( - B_base + (real_n * ke.ldb_orig + + B_base + (real_n * static_cast(ke.ldb_orig) + ke.K_chunks * ke.K_blk) * ke.b_dt_size); ke.brg_ktail[mi][ni].execute(A_ktail, const_cast(B_ktail), @@ -390,19 +430,18 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( const hn::ScalableTag df; const auto vscale = hn::Set(df, scale); const size_t lanes = hn::Lanes(df); - for (int64_t m = 0; m < cur_m; ++m) { + for (size_t m = 0; m < cur_m; ++m) { TC* C_row = C.Row(real_m + m) + real_n; const float* t_row = C_tile_ptr + m * cur_n; const float* add_row = add ? add + real_n : nullptr; - int64_t n = 0; + size_t n = 0; if (add_row) { - for (; n + static_cast(lanes) <= cur_n; - n += static_cast(lanes)) { + for (; n + lanes <= cur_n; n += lanes) { const auto v = hn::Load(df, t_row + n); const auto va = hn::Load(df, add_row + n); const auto result = hn::MulAdd(v, vscale, va); if constexpr (hwy::IsSame()) { - hn::Store(result, df, reinterpret_cast(C_row) + n); + hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n); } else { const hn::Rebind dc; hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); @@ -413,12 +452,11 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( C_row[n] = hwy::ConvertScalarTo(val); } } else { - for (; n + static_cast(lanes) <= cur_n; - n += static_cast(lanes)) { + for (; n + lanes <= cur_n; n += lanes) { const auto v = hn::Load(df, t_row + n); const auto result = hn::Mul(v, vscale); if constexpr (hwy::IsSame()) { - hn::Store(result, df, reinterpret_cast(C_row) + n); + hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n); } else { const hn::Rebind dc; hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); @@ -434,9 +472,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( }; // Parallel dispatch: K-super outer, N middle, M inner (keeps B in L2). - const int64_t k_total_supers = + const size_t k_total_supers = ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0); - const int64_t k_iters = (k_total_supers > 0) ? 
k_total_supers : 1; + const size_t k_iters = (k_total_supers > 0) ? k_total_supers : size_t{1}; const size_t num_threads = ctx.pools.MaxWorkersPerCluster(); const size_t total_n_tiles = ke.N_total_tiles; @@ -466,12 +504,11 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( const size_t total_tc = total_m_tiles * n_tiles_in_range; float* tc_base = tbufs.EnsureTempC(total_tc); - for (int64_t ks = 0; ks < k_iters; ++ks) { + for (size_t ks = 0; ks < k_iters; ++ks) { size_t n_idx = 0; for (size_t nt = n_begin; nt < n_end; ++nt) { const size_t n = nt * ke.N_blk; - for (int64_t mt = 0; mt < static_cast(total_m_tiles); - ++mt) { + for (size_t mt = 0; mt < total_m_tiles; ++mt) { const size_t m = mt * ke.M_blk; float* temp_C = tc_base + (mt * n_tiles_in_range + n_idx) * @@ -485,8 +522,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( dnnl::ukernel::brgemm::release_hw_context(); auto& main_bufs = GetBRGeMMThreadBufs(); - main_bufs.hw_ctx_set = false; main_bufs.hw_ctx_kernel = nullptr; } #endif // GEMMA_ONEDNN_BRGEMM + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/ops/brgemm.h b/ops/brgemm.h index 38e05509..07190aba 100644 --- a/ops/brgemm.h +++ b/ops/brgemm.h @@ -39,11 +39,11 @@ namespace gcpp { struct BRGeMMConfig { - int64_t M_blk; - int64_t N_blk; - int64_t K_blk; - int64_t batch_size; - int64_t par_m; + size_t M_blk; + size_t N_blk = 32; + size_t K_blk = 32; + size_t batch_size; + size_t par_m; }; #if GEMMA_ONEDNN_BRGEMM @@ -53,18 +53,19 @@ struct BRGeMMConfig { inline std::vector BRGeMMCandidates(size_t M, size_t K, size_t N) { std::vector out; - static constexpr int64_t kNBlk = 32; - static constexpr int64_t kKBlk = 32; - static constexpr int64_t kMBlkValues[] = {32, 64}; - static constexpr int64_t kBatchValues[] = {16, 32, 64, 128, 256}; - - const int64_t k_chunks = static_cast(K) / kKBlk; - for (int64_t mb : kMBlkValues) { - if (mb > static_cast(M)) 
continue; - if (kNBlk > static_cast(N)) continue; - for (int64_t bs : kBatchValues) { - const int64_t eff_bs = - (k_chunks > 0) ? std::min(bs, k_chunks) : int64_t{1}; + out.reserve(10); // At most 2 M_blk * 5 batch_size candidates. + static constexpr size_t kNBlk = 32; + static constexpr size_t kKBlk = 32; + static constexpr size_t kMBlkValues[] = {32, 64}; + static constexpr size_t kBatchValues[] = {16, 32, 64, 128, 256}; + + const size_t k_chunks = K / kKBlk; + for (size_t mb : kMBlkValues) { + if (mb > M) continue; + if (kNBlk > N) continue; + for (size_t bs : kBatchValues) { + const size_t eff_bs = + (k_chunks > 0) ? std::min(bs, k_chunks) : size_t{1}; bool dup = false; for (const auto& c : out) { if (c.M_blk == mb && c.batch_size == eff_bs) { @@ -77,8 +78,8 @@ inline std::vector BRGeMMCandidates(size_t M, size_t K, } } if (out.empty()) { - out.push_back({static_cast(std::min(M, size_t{32})), - static_cast(std::min(N, size_t{32})), 32, 1, 1}); + out.push_back({std::min(M, size_t{32}), + std::min(N, size_t{32}), 32, 1, 1}); } return out; } @@ -123,7 +124,8 @@ class HugePageBuffer { } madvise(ptr_, size_, MADV_HUGEPAGE); for (size_t off = 0; off < size_; off += kHugePageSize) { - static_cast(ptr_)[off] = 0; + ptr_[off] = 0; + hwy::PreventElision(ptr_[off]); } } @@ -139,7 +141,7 @@ class HugePageBuffer { // Kernel cache key: identifies a JIT-compiled kernel set. 
struct BRGeMMKernelKey { size_t M, K, N; - int64_t M_blk, N_blk, K_blk, batch_size; + size_t M_blk, N_blk, K_blk, batch_size; bool operator==(const BRGeMMKernelKey& o) const { return M == o.M && K == o.K && N == o.N && M_blk == o.M_blk && N_blk == o.N_blk && K_blk == o.K_blk && batch_size == o.batch_size; @@ -152,25 +154,25 @@ struct BRGeMMKernelKeyHash { h = (h ^ k.M) * 1099511628211ULL; h = (h ^ k.K) * 1099511628211ULL; h = (h ^ k.N) * 1099511628211ULL; - h = (h ^ static_cast(k.M_blk)) * 1099511628211ULL; - h = (h ^ static_cast(k.N_blk)) * 1099511628211ULL; - h = (h ^ static_cast(k.K_blk)) * 1099511628211ULL; - h = (h ^ static_cast(k.batch_size)) * 1099511628211ULL; + h = (h ^ k.M_blk) * 1099511628211ULL; + h = (h ^ k.N_blk) * 1099511628211ULL; + h = (h ^ k.K_blk) * 1099511628211ULL; + h = (h ^ k.batch_size) * 1099511628211ULL; return h; } }; // Cached JIT-compiled kernels with precomputed tile parameters and offsets. struct BRGeMMKernelEntry { - int64_t M_blk, N_blk, K_blk; - int64_t M_tail, N_tail, K_tail; - int64_t K_chunks; - int64_t M_full_tiles, N_full_tiles; - int64_t M_total_tiles, N_total_tiles; - int64_t K_super_size, K_super_blocks; - int64_t K_super_rem; - int64_t batch_full, batch_rem; - int64_t m_sizes[2], n_sizes[2]; + size_t M_blk, N_blk, K_blk; + size_t M_tail, N_tail, K_tail; + size_t K_chunks; + size_t M_full_tiles, N_full_tiles; + size_t M_total_tiles, N_total_tiles; + size_t K_super_size, K_super_blocks; + size_t K_super_rem; + size_t batch_full, batch_rem; + size_t m_sizes[2], n_sizes[2]; int64_t lda; int64_t ldb_orig; bool need_pack; @@ -203,7 +205,7 @@ struct BRGeMMKernelEntry { struct BRGeMMPackedBKey { uintptr_t B_ptr; size_t K, N; - int64_t K_blk, N_blk; + size_t K_blk, N_blk; bool operator==(const BRGeMMPackedBKey& o) const { return B_ptr == o.B_ptr && K == o.K && N == o.N && K_blk == o.K_blk && N_blk == o.N_blk; @@ -216,8 +218,8 @@ struct BRGeMMPackedBKeyHash { h = (h ^ k.B_ptr) * 1099511628211ULL; h = (h ^ k.K) * 1099511628211ULL; h 
= (h ^ k.N) * 1099511628211ULL; - h = (h ^ static_cast(k.K_blk)) * 1099511628211ULL; - h = (h ^ static_cast(k.N_blk)) * 1099511628211ULL; + h = (h ^ k.K_blk) * 1099511628211ULL; + h = (h ^ k.N_blk) * 1099511628211ULL; return h; } }; @@ -234,7 +236,6 @@ struct BRGeMMThreadBufs { std::vector scratch; std::vector tc_storage; - bool hw_ctx_set = false; const void* hw_ctx_kernel = nullptr; uint8_t* EnsureScratch(size_t size) { @@ -253,9 +254,8 @@ struct BRGeMMThreadBufs { void MaybeSetHwContext(const dnnl::ukernel::brgemm& brg) { const void* brg_ptr = &brg; - if (!hw_ctx_set || hw_ctx_kernel != brg_ptr) { + if (hw_ctx_kernel != brg_ptr) { brg.set_hw_context(); - hw_ctx_set = true; hw_ctx_kernel = brg_ptr; } } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index fda9d821..ce023f42 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -41,16 +41,15 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#if GEMMA_ONEDNN_BRGEMM +#include "ops/brgemm-inl.h" +#endif // GEMMA_ONEDNN_BRGEMM HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -#if GEMMA_ONEDNN_BRGEMM -#include "ops/brgemm-inl.h" // DoMatMul_BRGeMM -#endif // GEMMA_ONEDNN_BRGEMM - // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. 
template > static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { @@ -1110,11 +1109,10 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) { const BRGeMMConfig& best = *brg_tuner.Best(); fprintf(stderr, - "BRGeMM best: %zux%zux%zu M_blk=%ld N_blk=%ld K_blk=%ld " - "batch=%ld\n", - M, K, N, static_cast(best.M_blk), - static_cast(best.N_blk), static_cast(best.K_blk), - static_cast(best.batch_size)); + "BRGeMM best: %zux%zux%zu M_blk=%zu N_blk=%zu K_blk=%zu " + "batch=%zu\n", + M, K, N, best.M_blk, best.N_blk, best.K_blk, + best.batch_size); } return &per_key; } From 66400214c6dd45477a57e31f2f0a76d14268c0aa Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Tue, 5 May 2026 21:26:39 +0000 Subject: [PATCH 07/11] changed lda and ldb to size_t. Added conversions inplace for brgemm and transform inits --- ops/brgemm-inl.h | 37 ++++++++++++++++++++++--------------- ops/brgemm.h | 4 ++-- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index 08d8856c..1e29477b 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -90,6 +90,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ke.N_tail = N % ke.N_blk; ke.K_tail = K % ke.K_blk; + // Floor division: K_tail remainder is handled by a dedicated brg_ktail + // kernel rather than padding K, avoiding extra memory writes to zero-pad + // A and B along the K dimension. 
ke.K_chunks = K / ke.K_blk; ke.N_full_tiles = N / ke.N_blk; ke.M_full_tiles = M / ke.M_blk; @@ -146,9 +149,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (ke.K_chunks > 0) { if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, static_cast(ke.K_blk), - static_cast(ke.K_super_size), ke.lda, - ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, - false)) { + static_cast(ke.K_super_size), + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, false)) { return; } max_sp = std::max(max_sp, @@ -157,8 +160,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (ke.K_super_blocks > 1) { if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, static_cast(ke.K_blk), - static_cast(ke.batch_full), ke.lda, - ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, true)) { + static_cast(ke.batch_full), + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, true)) { return; } max_sp = @@ -169,8 +173,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( auto& target = rem_is_first ? ke.brg_first_rem[mi][ni] : ke.brg_rem[mi][ni]; if (!MakeBrgemm(target, ms, ns, static_cast(ke.K_blk), - static_cast(ke.batch_rem), ke.lda, - ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + static_cast(ke.batch_rem), + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, !rem_is_first)) { return; } @@ -179,8 +184,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( if (ke.K_tail > 0) { const bool add_c = (ke.K_chunks > 0); if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, - static_cast(ke.K_tail), 1, ke.lda, - ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + static_cast(ke.K_tail), 1, + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, add_c)) { return; } @@ -203,7 +209,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( static_cast(ke.K_chunks * ke.K_blk); try { ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, - ke.ldb_orig, ldb_for[ni], b_dt, b_dt); + static_cast(ke.ldb_orig), + ldb_for[ni], b_dt, b_dt); if (!ke.pack_B[ni]) return; ke.pack_B[ni].generate(); ke.blocked_B_size[ni] = 
static_cast(ldb_for[ni]) * @@ -216,7 +223,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( try { ke.pack_B_ktail[ni] = transform( static_cast(ke.K_tail), ns, pack_type::trans, - ke.ldb_orig, ldb_for[ni], b_dt, b_dt); + static_cast(ke.ldb_orig), ldb_for[ni], b_dt, b_dt); if (!ke.pack_B_ktail[ni]) return; ke.pack_B_ktail[ni].generate(); ke.blocked_B_ktail_size[ni] = @@ -327,7 +334,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ? nt * ke.N_blk : ke.N_full_tiles * ke.N_blk; const uint8_t* B_in = - B_base + b_row * static_cast(ke.ldb_orig) * ke.b_dt_size; + B_base + b_row * ke.ldb_orig * ke.b_dt_size; try { if (ke.K_chunks > 0) { @@ -375,14 +382,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( : ke.N_full_tiles * ke.N_blk; const uint8_t* A_tile = - A_base + real_m * static_cast(ke.lda) * ke.a_dt_size; + A_base + real_m * ke.lda * ke.a_dt_size; const void* B_tile = ke.need_pack ? static_cast(B_packed + pe.B_tile_offset[n_tile_idx]) : static_cast( B_base + - real_n * static_cast(ke.ldb_orig) * ke.b_dt_size); + real_n * ke.ldb_orig * ke.b_dt_size); float* C_tile_ptr = temp_C; const size_t k_total = @@ -419,7 +426,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ? 
static_cast(B_packed + pe.B_ktail_offset[n_tile_idx]) : static_cast( - B_base + (real_n * static_cast(ke.ldb_orig) + + B_base + (real_n * ke.ldb_orig + ke.K_chunks * ke.K_blk) * ke.b_dt_size); ke.brg_ktail[mi][ni].execute(A_ktail, const_cast(B_ktail), diff --git a/ops/brgemm.h b/ops/brgemm.h index 07190aba..c09b67d8 100644 --- a/ops/brgemm.h +++ b/ops/brgemm.h @@ -173,8 +173,8 @@ struct BRGeMMKernelEntry { size_t K_super_rem; size_t batch_full, batch_rem; size_t m_sizes[2], n_sizes[2]; - int64_t lda; - int64_t ldb_orig; + size_t lda; + size_t ldb_orig; bool need_pack; size_t a_dt_size, b_dt_size; size_t N_padded; From 9d6bbeef04219176a3d4031d5134fec57b19b9ec Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Tue, 5 May 2026 23:13:27 +0000 Subject: [PATCH 08/11] Replaced / and % with Divide and Remainder utils from hwy::Divisor --- ops/brgemm-inl.h | 19 +++++++++++-------- ops/brgemm.h | 4 ++++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index 1e29477b..68b4e066 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -85,17 +85,20 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ke.K_blk = cfg.K_blk; ke.N_blk = cfg.N_blk; ke.M_blk = std::min(cfg.M_blk, M); + ke.div_M_blk = hwy::Divisor(ke.M_blk); + ke.div_N_blk = hwy::Divisor(ke.N_blk); + ke.div_K_blk = hwy::Divisor(ke.K_blk); - ke.M_tail = M % ke.M_blk; - ke.N_tail = N % ke.N_blk; - ke.K_tail = K % ke.K_blk; + ke.M_tail = ke.div_M_blk.Remainder(M); + ke.N_tail = ke.div_N_blk.Remainder(N); + ke.K_tail = ke.div_K_blk.Remainder(K); // Floor division: K_tail remainder is handled by a dedicated brg_ktail // kernel rather than padding K, avoiding extra memory writes to zero-pad // A and B along the K dimension. 
- ke.K_chunks = K / ke.K_blk; - ke.N_full_tiles = N / ke.N_blk; - ke.M_full_tiles = M / ke.M_blk; + ke.K_chunks = ke.div_K_blk.Divide(K); + ke.N_full_tiles = ke.div_N_blk.Divide(N); + ke.M_full_tiles = ke.div_M_blk.Divide(M); ke.N_total_tiles = ke.N_full_tiles + (ke.N_tail ? 1 : 0); ke.M_total_tiles = ke.M_full_tiles + (ke.M_tail ? 1 : 0); ke.N_padded = ke.N_total_tiles * ke.N_blk; @@ -367,8 +370,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( const auto execute_tile = [&](size_t m_start, size_t n_start, size_t k_super, float* temp_C, uint8_t* scratch) HWY_ATTR { - const size_t m_tile_idx = m_start / ke.M_blk; - const size_t n_tile_idx = n_start / ke.N_blk; + const size_t m_tile_idx = ke.div_M_blk.Divide(m_start); + const size_t n_tile_idx = ke.div_N_blk.Divide(n_start); const int mi = (m_tile_idx < ke.M_full_tiles) ? 0 : 1; const int ni = (n_tile_idx < ke.N_full_tiles) ? 0 : 1; const size_t cur_m = ke.m_sizes[mi]; diff --git a/ops/brgemm.h b/ops/brgemm.h index c09b67d8..a9d0f231 100644 --- a/ops/brgemm.h +++ b/ops/brgemm.h @@ -165,6 +165,10 @@ struct BRGeMMKernelKeyHash { // Cached JIT-compiled kernels with precomputed tile parameters and offsets. struct BRGeMMKernelEntry { size_t M_blk, N_blk, K_blk; + // Precomputed divisors for fast modulo/division by block sizes. 
+ hwy::Divisor div_M_blk{1}; + hwy::Divisor div_N_blk{1}; + hwy::Divisor div_K_blk{1}; size_t M_tail, N_tail, K_tail; size_t K_chunks; size_t M_full_tiles, N_full_tiles; From 45708eac8138692919f2a32637b10abff541ad3c Mon Sep 17 00:00:00 2001 From: Bibek Bhattarai Date: Wed, 6 May 2026 00:29:47 +0000 Subject: [PATCH 09/11] Moved the BRGeMM Kernel inits to a separate HWY_NOINLINE helper function --- ops/brgemm-inl.h | 409 ++++++++++++++++++++++++----------------------- 1 file changed, 212 insertions(+), 197 deletions(-) diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h index 68b4e066..18dadfc0 100644 --- a/ops/brgemm-inl.h +++ b/ops/brgemm-inl.h @@ -64,226 +64,202 @@ static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n, } } -template -static HWY_NOINLINE void DoMatMul_BRGeMM( - const MatPtrT& A, const MatPtrT& B, RowPtrs C, size_t M, - size_t K, size_t N, float scale, const float* HWY_RESTRICT add, - const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) { +// JIT-compiles brgemm kernels, B-packing transforms, and offset tables for +// the given matrix dimensions and tiling config. Returns false on failure. +static HWY_NOINLINE bool InitBRGeMMKernels( + const BRGeMMConfig& cfg, size_t M, size_t K, size_t N, size_t lda, + size_t ldb_orig, BRGeMMKernelEntry& ke) { using dnnl::ukernel::brgemm; using dnnl::ukernel::pack_type; using dnnl::ukernel::transform; - // Level-1 cache: kernels keyed on (M, K, N, config). 
- const BRGeMMKernelKey kern_key{M, K, N, cfg.M_blk, cfg.N_blk, cfg.K_blk, - cfg.batch_size}; - auto& kern_cache = GetBRGeMMKernelCache(); - auto kern_it = kern_cache.find(kern_key); + ke.K_blk = cfg.K_blk; + ke.N_blk = cfg.N_blk; + ke.M_blk = std::min(cfg.M_blk, M); + ke.div_M_blk = hwy::Divisor(ke.M_blk); + ke.div_N_blk = hwy::Divisor(ke.N_blk); + ke.div_K_blk = hwy::Divisor(ke.K_blk); + + ke.M_tail = ke.div_M_blk.Remainder(M); + ke.N_tail = ke.div_N_blk.Remainder(N); + ke.K_tail = ke.div_K_blk.Remainder(K); + + // Floor division: K_tail remainder is handled by a dedicated brg_ktail + // kernel rather than padding K, avoiding extra memory writes to zero-pad + // A and B along the K dimension. + ke.K_chunks = ke.div_K_blk.Divide(K); + ke.N_full_tiles = ke.div_N_blk.Divide(N); + ke.M_full_tiles = ke.div_M_blk.Divide(M); + ke.N_total_tiles = ke.N_full_tiles + (ke.N_tail ? 1 : 0); + ke.M_total_tiles = ke.M_full_tiles + (ke.M_tail ? 1 : 0); + ke.N_padded = ke.N_total_tiles * ke.N_blk; + + if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0 || + (ke.K_chunks == 0 && ke.K_tail == 0)) { + return false; + } - if (kern_it == kern_cache.end()) { - BRGeMMKernelEntry ke; + ke.K_super_size = std::min(cfg.batch_size, ke.K_chunks); + ke.K_super_blocks = (ke.K_chunks > 0) ? ke.K_chunks / ke.K_super_size : 0; + ke.K_super_rem = (ke.K_chunks > 0) ? 
ke.K_chunks % ke.K_super_size : 0; + ke.batch_full = ke.K_super_size; + ke.batch_rem = ke.K_super_rem; + + const auto a_dt = dnnl::memory::data_type::bf16; + const auto b_dt = dnnl::memory::data_type::bf16; + const auto c_dt = dnnl::memory::data_type::f32; + ke.a_dt_size = dnnl::memory::data_type_size(a_dt); + ke.b_dt_size = dnnl::memory::data_type_size(b_dt); + + const auto pack = brgemm::get_B_pack_type(a_dt, b_dt); + if (pack == pack_type::undef) return false; + ke.need_pack = (pack != pack_type::no_trans); + + ke.lda = lda; + ke.ldb_orig = ldb_orig; + + // Indexed by tail flag: [0] = full tile size, [1] = tail size (or full if + // no tail). Separate kernels are JIT-compiled for full vs. tail tile widths + // along both M and N dimensions. + ke.m_sizes[0] = ke.M_blk; + ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk; + ke.n_sizes[0] = ke.N_blk; + ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk; + const int64_t ldb_for[2] = {static_cast(ke.N_blk), + static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; + const int64_t ldc_for[2] = {static_cast(ke.N_blk), + static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; + + // JIT a brgemm kernel for each (mi, ni) where mi/ni indicate whether we + // are processing the M-tail or N-tail: 0 = full block, 1 = tail block. + // Skipped when the corresponding tail is zero (no partial tile exists). 
+ size_t max_sp = 0; + for (int mi = 0; mi < 2; ++mi) { + for (int ni = 0; ni < 2; ++ni) { + if (mi == 1 && ke.M_tail == 0) continue; + if (ni == 1 && ke.N_tail == 0) continue; + if (mi == 0 && ke.M_full_tiles == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; - ke.K_blk = cfg.K_blk; - ke.N_blk = cfg.N_blk; - ke.M_blk = std::min(cfg.M_blk, M); - ke.div_M_blk = hwy::Divisor(ke.M_blk); - ke.div_N_blk = hwy::Divisor(ke.N_blk); - ke.div_K_blk = hwy::Divisor(ke.K_blk); - - ke.M_tail = ke.div_M_blk.Remainder(M); - ke.N_tail = ke.div_N_blk.Remainder(N); - ke.K_tail = ke.div_K_blk.Remainder(K); - - // Floor division: K_tail remainder is handled by a dedicated brg_ktail - // kernel rather than padding K, avoiding extra memory writes to zero-pad - // A and B along the K dimension. - ke.K_chunks = ke.div_K_blk.Divide(K); - ke.N_full_tiles = ke.div_N_blk.Divide(N); - ke.M_full_tiles = ke.div_M_blk.Divide(M); - ke.N_total_tiles = ke.N_full_tiles + (ke.N_tail ? 1 : 0); - ke.M_total_tiles = ke.M_full_tiles + (ke.M_tail ? 1 : 0); - ke.N_padded = ke.N_total_tiles * ke.N_blk; - - if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0 || - (ke.K_chunks == 0 && ke.K_tail == 0)) { - return; - } + const int64_t ms = static_cast(ke.m_sizes[mi]); + const int64_t ns = static_cast(ke.n_sizes[ni]); - ke.K_super_size = std::min(cfg.batch_size, ke.K_chunks); - ke.K_super_blocks = (ke.K_chunks > 0) ? ke.K_chunks / ke.K_super_size : 0; - ke.K_super_rem = (ke.K_chunks > 0) ? 
ke.K_chunks % ke.K_super_size : 0; - ke.batch_full = ke.K_super_size; - ke.batch_rem = ke.K_super_rem; - - const auto a_dt = dnnl::memory::data_type::bf16; - const auto b_dt = dnnl::memory::data_type::bf16; - const auto c_dt = dnnl::memory::data_type::f32; - ke.a_dt_size = dnnl::memory::data_type_size(a_dt); - ke.b_dt_size = dnnl::memory::data_type_size(b_dt); - - const auto pack = brgemm::get_B_pack_type(a_dt, b_dt); - if (pack == pack_type::undef) return; - ke.need_pack = (pack != pack_type::no_trans); - - ke.lda = A.Stride(); - ke.ldb_orig = B.Stride(); - - ke.m_sizes[0] = ke.M_blk; - ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk; - ke.n_sizes[0] = ke.N_blk; - ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk; - const int64_t ldb_for[2] = {static_cast(ke.N_blk), - static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; - const int64_t ldc_for[2] = {static_cast(ke.N_blk), - static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; - - // Create brgemm kernels for full/tail M and N tile sizes. - // mi=0 is the full M tile, mi=1 is the M-tail; likewise for ni and N. 
- size_t max_sp = 0; - for (int mi = 0; mi < 2; ++mi) { - for (int ni = 0; ni < 2; ++ni) { - if (mi == 1 && ke.M_tail == 0) continue; - if (ni == 1 && ke.N_tail == 0) continue; - if (mi == 0 && ke.M_full_tiles == 0) continue; - if (ni == 0 && ke.N_full_tiles == 0) continue; - - const int64_t ms = static_cast(ke.m_sizes[mi]); - const int64_t ns = static_cast(ke.n_sizes[ni]); - - if (ke.K_chunks > 0) { - if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, - static_cast(ke.K_blk), - static_cast(ke.K_super_size), - static_cast(ke.lda), ldb_for[ni], - ldc_for[ni], a_dt, b_dt, c_dt, false)) { - return; - } - max_sp = std::max(max_sp, - ke.brg_first_all[mi][ni].get_scratchpad_size()); - } - if (ke.K_super_blocks > 1) { - if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, - static_cast(ke.K_blk), - static_cast(ke.batch_full), - static_cast(ke.lda), ldb_for[ni], - ldc_for[ni], a_dt, b_dt, c_dt, true)) { - return; - } - max_sp = - std::max(max_sp, ke.brg_full[mi][ni].get_scratchpad_size()); - } - if (ke.K_super_rem > 0) { - const bool rem_is_first = (ke.K_super_blocks == 0); - auto& target = rem_is_first ? 
ke.brg_first_rem[mi][ni] - : ke.brg_rem[mi][ni]; - if (!MakeBrgemm(target, ms, ns, static_cast(ke.K_blk), - static_cast(ke.batch_rem), - static_cast(ke.lda), ldb_for[ni], - ldc_for[ni], a_dt, b_dt, c_dt, - !rem_is_first)) { - return; - } - max_sp = std::max(max_sp, target.get_scratchpad_size()); + if (ke.K_chunks > 0) { + if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, + static_cast(ke.K_blk), + static_cast(ke.K_super_size), + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, false)) { + return false; } - if (ke.K_tail > 0) { - const bool add_c = (ke.K_chunks > 0); - if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, - static_cast(ke.K_tail), 1, - static_cast(ke.lda), ldb_for[ni], - ldc_for[ni], a_dt, b_dt, c_dt, - add_c)) { - return; - } - max_sp = - std::max(max_sp, ke.brg_ktail[mi][ni].get_scratchpad_size()); + max_sp = std::max(max_sp, + ke.brg_first_all[mi][ni].get_scratchpad_size()); + } + if (ke.K_super_blocks > 1) { + if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, + static_cast(ke.K_blk), + static_cast(ke.batch_full), + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, true)) { + return false; } + max_sp = + std::max(max_sp, ke.brg_full[mi][ni].get_scratchpad_size()); } - } - ke.scratchpad_size = max_sp + 64; - - // Create B-packing transforms. - if (ke.need_pack) { - for (int ni = 0; ni < 2; ++ni) { - if (ni == 1 && ke.N_tail == 0) continue; - if (ni == 0 && ke.N_full_tiles == 0) continue; - - const int64_t ns = static_cast(ke.n_sizes[ni]); - if (ke.K_chunks > 0) { - const int64_t K_full = - static_cast(ke.K_chunks * ke.K_blk); - try { - ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, - static_cast(ke.ldb_orig), - ldb_for[ni], b_dt, b_dt); - if (!ke.pack_B[ni]) return; - ke.pack_B[ni].generate(); - ke.blocked_B_size[ni] = static_cast(ldb_for[ni]) * - ke.K_blk * ke.b_dt_size; - } catch (...) { - return; - } + if (ke.K_super_rem > 0) { + const bool rem_is_first = (ke.K_super_blocks == 0); + auto& target = rem_is_first ? 
ke.brg_first_rem[mi][ni] + : ke.brg_rem[mi][ni]; + if (!MakeBrgemm(target, ms, ns, static_cast(ke.K_blk), + static_cast(ke.batch_rem), + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, + !rem_is_first)) { + return false; } - if (ke.K_tail > 0) { - try { - ke.pack_B_ktail[ni] = transform( - static_cast(ke.K_tail), ns, pack_type::trans, - static_cast(ke.ldb_orig), ldb_for[ni], b_dt, b_dt); - if (!ke.pack_B_ktail[ni]) return; - ke.pack_B_ktail[ni].generate(); - ke.blocked_B_ktail_size[ni] = - static_cast(ldb_for[ni]) * ke.K_tail * ke.b_dt_size; - } catch (...) { - return; - } + max_sp = std::max(max_sp, target.get_scratchpad_size()); + } + if (ke.K_tail > 0) { + const bool add_c = (ke.K_chunks > 0); + if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, + static_cast(ke.K_tail), 1, + static_cast(ke.lda), ldb_for[ni], + ldc_for[ni], a_dt, b_dt, c_dt, + add_c)) { + return false; } + max_sp = + std::max(max_sp, ke.brg_ktail[mi][ni].get_scratchpad_size()); } } + } + ke.scratchpad_size = max_sp + 64; - // Precompute A/B offset tables for each K-super-block. + // Create B-packing transforms. + if (ke.need_pack) { for (int ni = 0; ni < 2; ++ni) { if (ni == 1 && ke.N_tail == 0) continue; if (ni == 0 && ke.N_full_tiles == 0) continue; - const size_t cur_n = ke.n_sizes[ni]; + const int64_t ns = static_cast(ke.n_sizes[ni]); if (ke.K_chunks > 0) { - ke.offsets_first_all[ni].resize(ke.K_super_size); - for (size_t i = 0; i < ke.K_super_size; ++i) { - const int64_t a_off = - static_cast(i * ke.K_blk * ke.a_dt_size); - const int64_t b_off = - ke.need_pack - ? 
static_cast(i * ke.blocked_B_size[ni]) - : static_cast(i * cur_n * ke.K_blk * ke.b_dt_size); - ke.offsets_first_all[ni][i] = {a_off, b_off}; + const int64_t K_full = + static_cast(ke.K_chunks * ke.K_blk); + try { + ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, + static_cast(ke.ldb_orig), + ldb_for[ni], b_dt, b_dt); + if (!ke.pack_B[ni]) return false; + ke.pack_B[ni].generate(); + ke.blocked_B_size[ni] = static_cast(ldb_for[ni]) * + ke.K_blk * ke.b_dt_size; + } catch (...) { + return false; } } - - if (ke.K_super_blocks > 1) { - ke.offsets_full[ni].resize(ke.K_super_blocks - 1); - for (size_t ks = 1; ks < ke.K_super_blocks; ++ks) { - auto& tbl = ke.offsets_full[ni][ks - 1]; - tbl.resize(ke.batch_full); - const size_t k_start = ks * ke.K_super_size; - for (size_t i = 0; i < ke.batch_full; ++i) { - const size_t k_idx = k_start + i; - const int64_t a_off = - static_cast(k_idx * ke.K_blk * ke.a_dt_size); - const int64_t b_off = - ke.need_pack - ? static_cast(k_idx * ke.blocked_B_size[ni]) - : static_cast(k_idx * cur_n * ke.K_blk * - ke.b_dt_size); - tbl[i] = {a_off, b_off}; - } + if (ke.K_tail > 0) { + try { + ke.pack_B_ktail[ni] = transform( + static_cast(ke.K_tail), ns, pack_type::trans, + static_cast(ke.ldb_orig), ldb_for[ni], b_dt, b_dt); + if (!ke.pack_B_ktail[ni]) return false; + ke.pack_B_ktail[ni].generate(); + ke.blocked_B_ktail_size[ni] = + static_cast(ldb_for[ni]) * ke.K_tail * ke.b_dt_size; + } catch (...) { + return false; } } + } + } - if (ke.K_super_rem > 0) { - const size_t k_base = ke.K_super_blocks * ke.K_super_size; - auto& rem_tbl = (ke.K_super_blocks == 0) ? ke.offsets_first_rem[ni] - : ke.offsets_rem[ni]; - rem_tbl.resize(ke.K_super_rem); - for (size_t i = 0; i < ke.K_super_rem; ++i) { - const size_t k_idx = k_base + i; + // Precompute A/B offset tables for each K-super-block. 
+ for (int ni = 0; ni < 2; ++ni) { + if (ni == 1 && ke.N_tail == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + const size_t cur_n = ke.n_sizes[ni]; + + if (ke.K_chunks > 0) { + ke.offsets_first_all[ni].resize(ke.K_super_size); + for (size_t i = 0; i < ke.K_super_size; ++i) { + const int64_t a_off = + static_cast(i * ke.K_blk * ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? static_cast(i * ke.blocked_B_size[ni]) + : static_cast(i * cur_n * ke.K_blk * ke.b_dt_size); + ke.offsets_first_all[ni][i] = {a_off, b_off}; + } + } + + if (ke.K_super_blocks > 1) { + ke.offsets_full[ni].resize(ke.K_super_blocks - 1); + for (size_t ks = 1; ks < ke.K_super_blocks; ++ks) { + auto& tbl = ke.offsets_full[ni][ks - 1]; + tbl.resize(ke.batch_full); + const size_t k_start = ks * ke.K_super_size; + for (size_t i = 0; i < ke.batch_full; ++i) { + const size_t k_idx = k_start + i; const int64_t a_off = static_cast(k_idx * ke.K_blk * ke.a_dt_size); const int64_t b_off = @@ -291,16 +267,55 @@ static HWY_NOINLINE void DoMatMul_BRGeMM( ? static_cast(k_idx * ke.blocked_B_size[ni]) : static_cast(k_idx * cur_n * ke.K_blk * ke.b_dt_size); - rem_tbl[i] = {a_off, b_off}; + tbl[i] = {a_off, b_off}; } } } + if (ke.K_super_rem > 0) { + const size_t k_base = ke.K_super_blocks * ke.K_super_size; + auto& rem_tbl = (ke.K_super_blocks == 0) ? ke.offsets_first_rem[ni] + : ke.offsets_rem[ni]; + rem_tbl.resize(ke.K_super_rem); + for (size_t i = 0; i < ke.K_super_rem; ++i) { + const size_t k_idx = k_base + i; + const int64_t a_off = + static_cast(k_idx * ke.K_blk * ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? 
static_cast(k_idx * ke.blocked_B_size[ni]) + : static_cast(k_idx * cur_n * ke.K_blk * + ke.b_dt_size); + rem_tbl[i] = {a_off, b_off}; + } + } + } + + return true; +} + +template +static HWY_NOINLINE void DoMatMul_BRGeMM( + const MatPtrT& A, const MatPtrT& B, RowPtrs C, size_t M, + size_t K, size_t N, float scale, const float* HWY_RESTRICT add, + const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) { + using dnnl::ukernel::brgemm; + + // Level-1 cache: kernels keyed on (M, K, N, config). + const BRGeMMKernelKey kern_key{M, K, N, cfg.M_blk, cfg.N_blk, cfg.K_blk, + cfg.batch_size}; + auto& kern_cache = GetBRGeMMKernelCache(); + auto kern_it = kern_cache.find(kern_key); + + if (kern_it == kern_cache.end()) { + BRGeMMKernelEntry ke; + if (!InitBRGeMMKernels(cfg, M, K, N, A.Stride(), B.Stride(), ke)) { + return; + } kern_it = kern_cache.emplace(kern_key, std::move(ke)).first; } BRGeMMKernelEntry& ke = kern_it->second; - if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0) return; // Level-2 cache: packed B keyed on (B_ptr, K, N, config). 
const uint8_t* A_base = reinterpret_cast<const uint8_t*>(A.Row(0));

From acf75926130a79fa292c32c0311f9c954a5281a8 Mon Sep 17 00:00:00 2001
From: Bibek Bhattarai
Date: Wed, 6 May 2026 00:54:43 +0000
Subject: [PATCH 10/11] Added HWY_WARN and fallback instead of exiting

---
 ops/brgemm-inl.h | 30 ++++++++++++++++++++++++------
 ops/matmul-inl.h | 38 +++++++++++++++++++++-----------------
 2 files changed, 45 insertions(+), 23 deletions(-)

diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h
index 18dadfc0..20c810cc 100644
--- a/ops/brgemm-inl.h
+++ b/ops/brgemm-inl.h
@@ -54,12 +54,25 @@ static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n,
   try {
     brg = dnnl::ukernel::brgemm(m, n, k, batch, lda, ldb, ldc, a_dt, b_dt,
                                 c_dt, true);
-    if (!brg) return false;
+    if (!brg) {
+      HWY_WARN("BRGeMM: kernel creation failed m=%lld n=%lld k=%lld.",
+               static_cast<long long>(m), static_cast<long long>(n),
+               static_cast<long long>(k));
+      return false;
+    }
     brg.set_add_C(add_C);
-    if (!brg.finalize()) return false;
+    if (!brg.finalize()) {
+      HWY_WARN("BRGeMM: kernel finalize failed m=%lld n=%lld k=%lld.",
+               static_cast<long long>(m), static_cast<long long>(n),
+               static_cast<long long>(k));
+      return false;
+    }
     brg.generate();
     return true;
   } catch (...)
{
+    HWY_WARN("BRGeMM: kernel JIT exception m=%lld n=%lld k=%lld.",
+             static_cast<long long>(m), static_cast<long long>(n),
+             static_cast<long long>(k));
     return false;
   }
 }
@@ -295,7 +308,7 @@ static HWY_NOINLINE bool InitBRGeMMKernels(
 }
 
 template
-static HWY_NOINLINE void DoMatMul_BRGeMM(
+static HWY_NOINLINE bool DoMatMul_BRGeMM(
     const MatPtrT& A, const MatPtrT& B, RowPtrs C, size_t M,
     size_t K, size_t N, float scale, const float* HWY_RESTRICT add,
     const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) {
@@ -310,7 +323,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
   if (kern_it == kern_cache.end()) {
     BRGeMMKernelEntry ke;
     if (!InitBRGeMMKernels(cfg, M, K, N, A.Stride(), B.Stride(), ke)) {
-      return;
+      return false;
     }
     kern_it = kern_cache.emplace(kern_key, std::move(ke)).first;
   }
@@ -344,7 +357,10 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
   pe.B_packed_buf.Resize(total_packed);
   uint8_t* B_packed = pe.B_packed_buf.data();
-  if (!B_packed) return;
+  if (!B_packed) {
+    HWY_WARN("BRGeMM: packed B allocation failed.");
+    return false;
+  }
 
   for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) {
     const int ni = (nt < ke.N_full_tiles) ? 0 : 1;
@@ -366,7 +382,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
                    B_packed + pe.B_ktail_offset[nt]);
       }
     } catch (...)
{
-      return;
+      HWY_WARN("BRGeMM: B-packing execution failed.");
+      return false;
     }
   }
 }
@@ -548,6 +565,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
   dnnl::ukernel::brgemm::release_hw_context();
   auto& main_bufs = GetBRGeMMThreadBufs();
   main_bufs.hw_ctx_kernel = nullptr;
+  return true;
 }
 
 #endif  // GEMMA_ONEDNN_BRGEMM
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index ce023f42..869d315e 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -1089,9 +1089,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B,
       MMAutoTune& brg_tuner = per_key.brgemm_autotune;
       if (HWY_LIKELY(brg_tuner.Best())) {
-        DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, *brg_tuner.Best(),
-                        env.ctx, cluster_idx);
-        return &per_key;
+        if (DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add,
+                            *brg_tuner.Best(), env.ctx, cluster_idx)) {
+          return &per_key;
+        }
+        // BRGeMM failed; fall through to standard matmul.
       }
 
       if (HWY_UNLIKELY(!brg_tuner.HasCandidates())) {
@@ -1100,21 +1102,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B,
        const BRGeMMConfig& cfg = brg_tuner.NextConfig();
 
        const uint64_t t0 = hwy::timer::Start();
-       DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx,
-                       cluster_idx);
-       const uint64_t t1 =
-           env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
-       brg_tuner.NotifyTicks(t1 - t0);
-
-       if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) {
-         const BRGeMMConfig& best = *brg_tuner.Best();
-         fprintf(stderr,
-                 "BRGeMM best: %zux%zux%zu M_blk=%zu N_blk=%zu K_blk=%zu "
-                 "batch=%zu\n",
-                 M, K, N, best.M_blk, best.N_blk, best.K_blk,
-                 best.batch_size);
+       if (DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx,
+                           cluster_idx)) {
+         const uint64_t t1 =
+             env.have_timer_stop ?
hwy::timer::Stop() : hwy::timer::Start();
+         brg_tuner.NotifyTicks(t1 - t0);
+
+         if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) {
+           const BRGeMMConfig& best = *brg_tuner.Best();
+           fprintf(stderr,
+                   "BRGeMM best: %zux%zux%zu M_blk=%zu N_blk=%zu K_blk=%zu "
+                   "batch=%zu\n",
+                   M, K, N, best.M_blk, best.N_blk, best.K_blk,
+                   best.batch_size);
+         }
+         return &per_key;
       }
-      return &per_key;
+      // BRGeMM failed; fall through to standard matmul.
     }
   }  // if constexpr BF16/float
 #endif  // GEMMA_ONEDNN_BRGEMM

From 7bdf4c6109ab22bf1757f077117182b4ad0a0b71 Mon Sep 17 00:00:00 2001
From: Bibek Bhattarai
Date: Wed, 6 May 2026 02:05:09 +0000
Subject: [PATCH 11/11] using hwy::AlignedVector instead of std::vector for
 scratch and tc_storage

---
 ops/brgemm.h | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/ops/brgemm.h b/ops/brgemm.h
index a9d0f231..fd6279e9 100644
--- a/ops/brgemm.h
+++ b/ops/brgemm.h
@@ -27,6 +27,7 @@
 #include
 #include
+#include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 
 #if GEMMA_ONEDNN_BRGEMM
@@ -238,22 +239,19 @@ struct BRGeMMPackedBEntry {
 struct BRGeMMThreadBufs {
   static constexpr size_t kMaxTempCSize = 64 * 64;
 
-  std::vector<uint8_t> scratch;
-  std::vector<uint8_t> tc_storage;
+  hwy::AlignedVector<uint8_t> scratch;
+  hwy::AlignedVector<uint8_t> tc_storage;
   const void* hw_ctx_kernel = nullptr;
 
   uint8_t* EnsureScratch(size_t size) {
-    if (scratch.size() < size + 64) scratch.resize(size + 64);
-    return scratch.data() +
-           (64 - (reinterpret_cast<uintptr_t>(scratch.data()) % 64));
+    if (scratch.size() < size) scratch.resize(size);
+    return scratch.data();
   }
 
   float* EnsureTempC(size_t n_tiles) {
-    const size_t need = n_tiles * kMaxTempCSize * sizeof(float) + 64;
+    const size_t need = n_tiles * kMaxTempCSize * sizeof(float);
     if (tc_storage.size() < need) tc_storage.resize(need);
-    return reinterpret_cast<float*>(
-        (reinterpret_cast<uintptr_t>(tc_storage.data()) + 63) &
-        ~uintptr_t{63});
+    return reinterpret_cast<float*>(tc_storage.data());
  }
 
   void MaybeSetHwContext(const dnnl::ukernel::brgemm& brg) {