From 163f6149967fc93bde668132bfd06ff43002c1cb Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 23 Mar 2026 16:46:10 +0800 Subject: [PATCH 01/17] [CPU/CUDA EP] Add DeformConv op support (#27393) ### Description This change adds support for the Deformable Convolution 2D operator (DeformConv2D) to ONNX Runtime. The branch implements the operator schema and registration, provides kernel implementations (CPU and GPU/CUDA where available), implements shape inference, and adds unit and integration tests to validate correctness and numerical parity with reference implementations. The changes include performance-oriented optimizations and necessary changes to build/test scripts. ### Motivation and Context Deformable convolutions are widely used in vision models that require spatial sampling flexibility (e.g., Deformable ConvNets, some detection/segmentation models). Native support in ONNX Runtime enables these models to run efficiently without custom operators or external runtimes, broadening the set of compatible models and improving performance and portability. ### See also - https://onnx.ai/onnx/operators/onnx__DeformConv.html - https://docs.pytorch.org/vision/main/generated/torchvision.ops.deform_conv2d.html - https://arxiv.org/abs/1811.11168 - https://arxiv.org/abs/1703.06211 - https://github.com/pytorch/vision/blob/0f6d91d9fe514e6de2f5519114cbeb389d498b2d/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu - https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp - https://github.com/pytorch/vision/blob/0f6d91d9fe514e6de2f5519114cbeb389d498b2d/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp - https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp - #22060 - #15572 - #20810 - #16903 - https://github.com/onnx/onnx/issues/5451 - https://github.com/ZhengPeng7/BiRefNet/pull/167 - https://github.com/pytorch/pytorch/issues/68910 - https://github.com/pytorch/vision/issues/2066 --- docs/OperatorKernels.md | 4 + .../providers/cpu/cpu_execution_provider.cc | 8 + .../core/providers/cpu/nn/deform_conv.cc | 321 ++++++ .../core/providers/cpu/nn/deform_conv.h | 24 + .../providers/cpu/nn/deform_conv_attributes.h | 198 ++++ .../providers/cuda/cuda_execution_provider.cc | 14 + .../core/providers/cuda/nn/deform_conv.cc | 331 ++++++ .../core/providers/cuda/nn/deform_conv.h | 27 + .../providers/cuda/nn/deform_conv_impl.cu | 512 ++++++++++ .../core/providers/cuda/nn/deform_conv_impl.h | 63 ++ .../cpu/nn/deform_conv_expected_gen.py | 180 ++++ .../providers/cpu/nn/deform_conv_op_test.cc | 948 ++++++++++++++++++ .../test/testdata/deform_conv_test.onnx | Bin 0 -> 353 bytes .../test/testdata/deform_conv_test_data.inc | 10 + .../test/testdata/deform_conv_test_data.npz | Bin 0 -> 2092 bytes .../test/testdata/nn/deform_conv_test_gen.py | 194 ++++ 16 files changed, 2834 insertions(+) create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv.cc create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv.h create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv.cc create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv.h create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv_impl.h create mode 100644 onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py create mode 100644 onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc create mode 100644 onnxruntime/test/testdata/deform_conv_test.onnx create mode 100644 onnxruntime/test/testdata/deform_conv_test_data.inc create mode 100644 onnxruntime/test/testdata/deform_conv_test_data.npz create mode 100644 onnxruntime/test/testdata/nn/deform_conv_test_gen.py diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 39c9145a40912..25829d08206cf 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -103,6 +103,8 @@ Do not modify directly.* |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DeformConv|*in* X:**T**
*in* W:**T**
*in* offset:**T**
*in* B:**T**
*in* mask:**T**
*out* Y:**T**|22+|**T** = tensor(double), tensor(float)| +|||[19, 21]|**T** = tensor(double), tensor(float)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -697,6 +699,8 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| +|DeformConv|*in* X:**T**
*in* W:**T**
*in* offset:**T**
*in* B:**T**
*in* mask:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 74b8f8e468097..9f19a20a2e680 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1220,6 +1220,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); // Opset 20 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, 20, ConstantOfShape); @@ -1316,6 +1318,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Ac class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Atanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, DeformConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Det); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_float, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_double, Dropout); @@ -3277,6 +3281,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Resize)>, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc new file mode 100644 index 0000000000000..f128b0e0182ad --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -0,0 +1,321 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CPU implementation of DeformConv (deformable convolution 2D). + +#include "deform_conv.h" + +#include + +#include "core/common/common.h" +#include "core/util/math_cpuonly.h" +#include "core/common/narrow.h" +#include "core/util/math.h" + +namespace onnxruntime { + +namespace { +// Bilinear interpolation at (h, w). Out-of-bounds samples return 0 (ONNX spec). +// Indices use int (not int64_t) to reduce register pressure and improve occupancy in the hot path. +// Limitation: height and width must not exceed INT_MAX, or casting floor(h)/floor(w) to int may overflow. +// Acceptable in practice: deformable convolution spatial dimensions are typically well below INT_MAX. +template +T BilinearInterpolate(const T* in, int height, int width, T h, T w) { + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return static_cast(0); + } + + // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. + const T h_floor = std::floor(h); + const T w_floor = std::floor(w); + const int h_low = static_cast(h_floor); + const int w_low = static_cast(w_floor); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const T lh = h - h_floor; + const T lw = w - w_floor; + const T hh = static_cast(1) - lh; + const T hw = static_cast(1) - lw; + + // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)). + // Most sampling points in deformable conv fall here; avoids 4 per-corner branches. + // [Optimization 3]: Use unsigned comparison to avoid branch on negative height/width. + if (static_cast(h_low) < static_cast(height - 1) && + static_cast(w_low) < static_cast(width - 1)) { + const int base_low = h_low * width; + const int base_high = h_high * width; + return hh * hw * in[base_low + w_low] + + hh * lw * in[base_low + w_high] + + lh * hw * in[base_high + w_low] + + lh * lw * in[base_high + w_high]; + } + + // Slow path: near boundary (one or more of the 4 corners may be out of bounds). + const int base_low = h_low * width; + const int base_high = h_high * width; + const T v1 = (h_low >= 0 && w_low >= 0) ? in[base_low + w_low] : static_cast(0); + const T v2 = (h_low >= 0 && w_high < width) ? in[base_low + w_high] : static_cast(0); + const T v3 = (h_high < height && w_low >= 0) ? in[base_high + w_low] : static_cast(0); + const T v4 = (h_high < height && w_high < width) ? in[base_high + w_high] : static_cast(0); + return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4; +} + +// Deformable Im2Col for a SINGLE image. +// Converts the input image into a matrix suitable for GEMM by sampling with learned offsets. +// Output 'data_col' shape: [C_in * kH * kW, H_out * W_out] +// When UseMask=false, pass nullptr for data_mask; compiler eliminates dead code for mask. +template +void DeformableIm2col( + const T* data_im, // Input image [C, H, W] + const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] + const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false) + int height, int width, // Input spatial dimensions (validated H*W <= INT_MAX) + int64_t kernel_h, int64_t kernel_w, // Kernel dimensions + int64_t pad_h, int64_t pad_w, // Padding (begin) for H and W + int64_t stride_h, int64_t stride_w, // Stride for H and W + int64_t dilation_h, int64_t dilation_w, // Dilation for H and W + int64_t channels, // Input channels + int64_t offset_groups, // Number of offset groups (channels shared per group) + int64_t height_col, int64_t width_col, // Output spatial dimensions (H_out, W_out) + T* data_col, // Output buffer for im2col result + concurrency::ThreadPool* thread_pool) { + const int64_t channel_per_offset_group = channels / offset_groups; + const int64_t kernel_size = kernel_h * kernel_w; + const int64_t output_size = height_col * width_col; + + // Parallelize over (channel, kernel_position) so each task processes one full row of data_col. + // This yields channels*kernel_size tasks, better CPU utilization and cache-friendly sequential writes. + concurrency::ThreadPool::TryParallelFor( + thread_pool, + static_cast(channels * kernel_size), + static_cast(output_size) * 10.0, + [&](ptrdiff_t begin, ptrdiff_t end) { + for (ptrdiff_t idx = begin; idx < end; ++idx) { + // Decompose idx into (c_im, i, j): which channel and kernel position. + const int64_t j = static_cast(idx) % kernel_w; + const int64_t i = (static_cast(idx) / kernel_w) % kernel_h; + const int64_t c_im = static_cast(idx) / kernel_size; + const int64_t offset_grp = c_im / channel_per_offset_group; + + // Output row: one (channel, kernel_pos) across all spatial locations. + T* col_ptr = data_col + static_cast(idx) * output_size; + const T* im_ptr = data_im + c_im * static_cast(height) * width; + + // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened. + // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w. + // Precompute pointers to avoid offset_base * output_size multiplication in inner loop. + const int64_t offset_base = + offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j); + const T* ptr_offset_h = data_offset + offset_base * output_size; + const T* ptr_offset_w = data_offset + (offset_base + 1) * output_size; + + // Base terms for h_im, w_im: invariant in inner loop (i, j fixed). + const T base_h = -pad_h + static_cast(i) * dilation_h; + const T base_w = -pad_w + static_cast(j) * dilation_w; + + // Mask pointer; only used when UseMask=true (compiler removes when false). + [[maybe_unused]] const T* ptr_mask = nullptr; + if constexpr (UseMask) { + ptr_mask = data_mask + (offset_grp * kernel_size + i * kernel_w + j) * output_size; + } + + // Loop over output spatial positions. + for (int64_t h_col = 0; h_col < height_col; ++h_col) { + for (int64_t w_col = 0; w_col < width_col; ++w_col) { + const int64_t spatial_idx = h_col * width_col + w_col; + + const T offset_h = ptr_offset_h[spatial_idx]; + const T offset_w = ptr_offset_w[spatial_idx]; + + // Deformed sampling coordinates (fractional, for bilinear interpolation). + const T h_im = h_col * stride_h + base_h + offset_h; + const T w_im = w_col * stride_w + base_w + offset_w; + + // Sample input at deformed location; returns 0 if out of bounds. + T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im); + + // Modulate by mask when UseMask=true; compiled away when false. + // Design choice: we always interpolate then multiply, rather than skip when mask==0. + // Rationale: (1) Skipping adds a branch; unpredictable mask values cause misprediction + // penalties (~15-20 cycles). (2) Straight-line code vectorizes better; conditional + // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN + // usage (moderate mask density), the unconditional path usually wins. + if constexpr (UseMask) { + val *= ptr_mask[spatial_idx]; + } + + col_ptr[spatial_idx] = val; + } + } + } + }); +} + +} // namespace + +template +Status DeformConv::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); // optional + const auto* mask = context->Input(4); // optional + + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse( + attrs_, + X->Shape(), + W->Shape(), + offset->Shape(), + B ? &B->Shape() : nullptr, + mask ? &mask->Shape() : nullptr, + params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; + + // Allocate output tensor [N, M, out_h, out_w]. + const TensorShape Y_shape({N, M, out_h, out_w}); + Tensor* Y = context->Output(0, Y_shape); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + // Precompute common sizes for the im2col + GEMM pipeline. + const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = C / group * kernel_size; // K dimension for GEMM: C/group * kH * kW + + // Col buffer: shape [C*kH*kW, out_h*out_w]. Allocate per-image (process one image at a time) + // to reduce peak memory when N is large; im2col is implemented per-image anyway. + const int64_t col_buffer_size = (C * kernel_size) * output_image_size; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ? B->Data() : nullptr; + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + // Process each image in the batch. + for (int64_t n = 0; n < N; ++n) { + // Step 1: Deformable Im2Col for image n. + // Gather deformed samples into col buffer for GEMM. + const T* X_curr = Xdata + n * (C * input_image_size); + const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr; + T* col_buffer_ptr = col_buffer.get(); + + // Dispatch to template instantiation: UseMask=true or false eliminates branch in hot loop. + // Note: pad_h, pad_w are begin-side paddings for coordinate mapping; pad_h_end/pad_w_end + // affect only output size (already baked into out_h, out_w), not im2col sampling. + if (use_mask) { + DeformableIm2col( + X_curr, offset_curr, mask_curr, + static_cast(H), static_cast(W_in), kH, kW, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + C, offset_group, out_h, out_w, + col_buffer_ptr, thread_pool); + } else { + DeformableIm2col( + X_curr, offset_curr, nullptr, + static_cast(H), static_cast(W_in), kH, kW, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + C, offset_group, out_h, out_w, + col_buffer_ptr, thread_pool); + } + + // Step 2: GEMM for each group. Y = W * Col (per group). + for (int64_t g = 0; g < group; ++g) { + // Weight for group g: shape [M/group, C/group, kH, kW], row-major. + const T* weight_g = Wdata + g * (M / group) * kernel_dim; + + // Col rows for group g: layout [C*kH*kW, out_h*out_w], group g spans rows [g*kernel_dim, (g+1)*kernel_dim). + const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size; + + // Output slice for group g: [n, g*M/group:(g+1)*M/group, out_h, out_w]. + T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size; + + // GEMM: Y = W * Col. W [M/group, kernel_dim], Col [kernel_dim, output_image_size]. + math::Gemm( + CblasNoTrans, + CblasNoTrans, + narrow(M / group), // M + narrow(output_image_size), // N + narrow(kernel_dim), // K + static_cast(1), // alpha + weight_g, // A + col_g, // B + static_cast(0), // beta + Y_g, // C + thread_pool, + nullptr); // mlas_backend_kernel_selector_config + } + } + + // Step 3: Add bias if provided (broadcast over spatial dimensions). + if (Bdata != nullptr) { + int64_t total_work = N * M; + concurrency::ThreadPool::TryParallelFor( + thread_pool, static_cast(total_work), static_cast(output_image_size), + [&](ptrdiff_t first, ptrdiff_t last) { + for (ptrdiff_t idx = first; idx < last; ++idx) { + int64_t n = idx / M; + int64_t m = idx % M; + T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size; + // Eigen vectorized add: Y_ptr += Bdata[m] over all spatial positions. + EigenVectorArrayMap(Y_ptr, narrow(output_image_size)) += Bdata[m]; + } + }); + } + + return Status::OK(); +} + +// Explicit template instantiation for float and double +template class DeformConv; +template class DeformConv; + +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + DeformConv, 19, 21, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + DeformConv, 22, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float) +REGISTER_DEFORMCONV_KERNEL_TYPED(double) + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h new file mode 100644 index 0000000000000..c8d7763e58bcb --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/op_node_proto_helper.h" +#include "deform_conv_attributes.h" + +namespace onnxruntime { + +template +class DeformConv : public OpKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : OpKernel(info), attrs_(info) {} + + Status Compute(OpKernelContext* context) const override; + + private: + DeformConvAttributes attrs_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h new file mode 100644 index 0000000000000..8bc891bb4f377 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { + +// Shared attributes for ONNX DeformConv (opset 19+). +// See https://onnx.ai/onnx/operators/onnx__DeformConv.html +// Used by both CPU and CUDA implementations (CUDA includes from here). +struct DeformConvAttributes { + explicit DeformConvAttributes(const OpKernelInfo& info) { + // Optional attributes. + // If not present, they will be empty/default, and handled in Compute/ComputeInternal. + (void)info.GetAttrs("kernel_shape", kernel_shape); + (void)info.GetAttrs("strides", strides); + (void)info.GetAttrs("pads", pads); + (void)info.GetAttrs("dilations", dilations); + group = info.GetAttrOrDefault("group", 1); + offset_group = info.GetAttrOrDefault("offset_group", 1); + } + + TensorShapeVector kernel_shape; + TensorShapeVector strides; + TensorShapeVector pads; + TensorShapeVector dilations; + int64_t group{1}; + int64_t offset_group{1}; +}; + +// Parsed and validated parameters from DeformConv inputs. +// Used by both CPU and CUDA implementations. +// Field names align with ONNX DeformConv spec: https://onnx.ai/onnx/operators/onnx__DeformConv.html +struct DeformConvParams { + // Input X shape (N, C, H, W) + int64_t N{0}; // Batch size + int64_t C{0}; // Number of input channels + int64_t H{0}; // Input height + int64_t W_in{0}; // Input width (W_in to avoid collision with weight W) + + // Weight W shape (oC, C/group, kH, kW) + int64_t M{0}; // Number of output channels (oC) + int64_t kH{0}; // Kernel height + int64_t kW{0}; // Kernel width + + // Pads [x1_begin, x2_begin, x1_end, x2_end] for spatial axes H, W + int64_t pad_h{0}; + int64_t pad_w{0}; + int64_t pad_h_end{0}; + int64_t pad_w_end{0}; + + // Strides and dilations along each spatial axis (default 1) + int64_t stride_h{1}; + int64_t stride_w{1}; + int64_t dilation_h{1}; + int64_t dilation_w{1}; + + // Attributes: C and oC must be divisible by group; C must be divisible by offset_group + int64_t group{1}; // Number of groups for input/output channels + int64_t offset_group{1}; // Number of groups of offset + + // Output Y shape (N, oC, oH, oW) + int64_t out_h{0}; // Output height (oH) + int64_t out_w{0}; // Output width (oW) + + bool use_mask{false}; // Whether optional mask input is provided +}; + +// Validates inputs and parses attributes into params. +// Returns Status::OK() on success; on failure, params may be partially filled. +inline Status DeformConvValidateAndParse( + const DeformConvAttributes& attrs, + const TensorShape& X_shape, + const TensorShape& W_shape, + const TensorShape& offset_shape, + const TensorShape* B_shape, + const TensorShape* mask_shape, + DeformConvParams& params) { + ORT_RETURN_IF_NOT(X_shape.NumDimensions() == 4, "Input X must be 4D (N, C, H, W)."); + ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); + ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + + // Parse input shapes + params.N = X_shape[0]; + params.C = X_shape[1]; + params.H = X_shape[2]; + params.W_in = X_shape[3]; + params.M = W_shape[0]; + ORT_RETURN_IF_NOT(params.N >= 0, "Batch size N must be non-negative."); + ORT_RETURN_IF_NOT(params.C > 0, "Input channels C must be positive."); + ORT_RETURN_IF_NOT(params.M > 0, "Output channels M (oC) must be positive."); + ORT_RETURN_IF_NOT(W_shape[1] > 0, "Weight W must have positive in-channels (W_shape[1] = C/group)."); + + // Handle kernel shape inference. If kernel_shape is provided, it must match weight spatial dims + // to avoid GEMM using wrong K and potential out-of-bounds reads from the weight buffer. + const int64_t W_kH = W_shape[2]; + const int64_t W_kW = W_shape[3]; + if (!attrs.kernel_shape.empty()) { + ORT_RETURN_IF_NOT(attrs.kernel_shape.size() == 2, + "kernel_shape must be absent or have exactly 2 values (kH, kW) for 2D DeformConv."); + ORT_RETURN_IF_NOT(attrs.kernel_shape[0] == W_kH && attrs.kernel_shape[1] == W_kW, + "kernel_shape must match weight spatial dimensions (W_shape[2], W_shape[3])."); + params.kH = attrs.kernel_shape[0]; + params.kW = attrs.kernel_shape[1]; + } else { + params.kH = W_kH; + params.kW = W_kW; + } + + // DeformConv is 2D-only: when an attribute is present, require exact length to avoid silently misinterpreting malformed models. + params.pad_h = params.pad_w = params.pad_h_end = params.pad_w_end = 0; + if (!attrs.pads.empty()) { + ORT_RETURN_IF_NOT(attrs.pads.size() == 4, + "pads must be absent or have exactly 4 values [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] for 2D DeformConv."); + params.pad_h = attrs.pads[0]; + params.pad_w = attrs.pads[1]; + params.pad_h_end = attrs.pads[2]; + params.pad_w_end = attrs.pads[3]; + ORT_RETURN_IF_NOT(params.pad_h >= 0 && params.pad_w >= 0 && params.pad_h_end >= 0 && params.pad_w_end >= 0, + "Pads must be non-negative (ONNX spec)."); + } + + if (!attrs.strides.empty()) { + ORT_RETURN_IF_NOT(attrs.strides.size() == 2, + "strides must be absent or have exactly 2 values [stride_h, stride_w] for 2D DeformConv."); + params.stride_h = attrs.strides[0]; + params.stride_w = attrs.strides[1]; + } else { + params.stride_h = params.stride_w = 1; + } + + if (!attrs.dilations.empty()) { + ORT_RETURN_IF_NOT(attrs.dilations.size() == 2, + "dilations must be absent or have exactly 2 values [dilation_h, dilation_w] for 2D DeformConv."); + params.dilation_h = attrs.dilations[0]; + params.dilation_w = attrs.dilations[1]; + } else { + params.dilation_h = params.dilation_w = 1; + } + params.group = attrs.group; + params.offset_group = attrs.offset_group; + params.use_mask = (mask_shape != nullptr); + + // Validate attributes + ORT_RETURN_IF_NOT(params.stride_h > 0 && params.stride_w > 0, "Strides must be positive."); + ORT_RETURN_IF_NOT(params.dilation_h > 0 && params.dilation_w > 0, "Dilations must be positive."); + ORT_RETURN_IF_NOT(params.kH > 0 && params.kW > 0, "Kernel shape must be positive."); + ORT_RETURN_IF_NOT(params.group > 0, "group must be positive"); + ORT_RETURN_IF_NOT(params.offset_group > 0, "offset_group must be positive"); + + params.out_h = (params.H + params.pad_h + params.pad_h_end - params.dilation_h * (params.kH - 1) - 1) / params.stride_h + 1; + params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1; + ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative."); + + // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math. + ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative."); + ORT_RETURN_IF_NOT(params.W_in <= static_cast(INT_MAX) / (params.H + 1), + "Input (H+1)*W must not exceed INT_MAX (for performance optimization)."); + + // Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW). + ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); + const int64_t offset_block = 2 * params.kH * params.kW; + ORT_RETURN_IF_NOT(offset_block > 0 && offset_shape[1] % offset_block == 0 && + offset_shape[1] / offset_block == params.offset_group, + "Offset channel count must be offset_group * 2 * kH * kW."); + ORT_RETURN_IF_NOT(offset_shape[2] == params.out_h, "Offset spatial height must match output oH."); + ORT_RETURN_IF_NOT(offset_shape[3] == params.out_w, "Offset spatial width must match output oW."); + ORT_RETURN_IF_NOT(params.C % params.offset_group == 0, "Input channels must be divisible by offset_group."); + ORT_RETURN_IF_NOT(params.C == W_shape[1] * params.group, "Input channels must match weight in channels * group."); + ORT_RETURN_IF_NOT(params.M % params.group == 0, "Output channels must be divisible by group."); + + if (B_shape != nullptr) { + ORT_RETURN_IF_NOT(B_shape->NumDimensions() == 1, "Bias B must be 1D."); + ORT_RETURN_IF_NOT((*B_shape)[0] == params.M, "Bias B must have shape [M] (M = number of output channels)."); + } + + // Validate mask if present + if (params.use_mask) { + ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D."); + ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size."); + const int64_t mask_block = params.kH * params.kW; + ORT_RETURN_IF_NOT(mask_block > 0 && (*mask_shape)[1] % mask_block == 0 && + (*mask_shape)[1] / mask_block == params.offset_group, + "Mask channel count must be offset_group * kH * kW."); + ORT_RETURN_IF_NOT((*mask_shape)[2] == params.out_h, "Mask spatial height must match output oH."); + ORT_RETURN_IF_NOT((*mask_shape)[3] == params.out_w, "Mask spatial width must match output oW."); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b87cf8cbc16c1..4f36789191bb2 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1455,6 +1455,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast); @@ -1596,6 +1599,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, DeformConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid); @@ -2574,6 +2581,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2706,6 +2716,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc new file mode 100644 index 0000000000000..7a0b896acfe01 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv (deformable convolution 2D). + +#include "core/providers/shared_library/provider_api.h" +#include "deform_conv.h" +#include "deform_conv_impl.h" + +#include + +#include "core/common/narrow.h" +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kMaxParallelImgs = 32; + +// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes. +// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately. +// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration +// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision. +int GetGreatestDivisorBelowBound(int n, int bound) { + if (bound <= 0 || n <= 0) return 1; + if (n % bound == 0) return bound; // Fast path: batch is multiple of target + + // n >= bound^2 <=> bound <= sqrt(n) => linear scan is cheaper + if (static_cast(n) >= static_cast(bound) * bound) { + for (int k = bound - 1; k > 1; --k) { + if (n % k == 0) return k; + } + } else { + // n < bound^2 <=> bound > sqrt(n) => divisor enumeration is cheaper + int best = 1; + for (int i = 1; static_cast(i) * i <= static_cast(n); ++i) { + if (n % i != 0) continue; + const int q = n / i; + if (q <= bound && q > best) best = q; + if (i <= bound && i > best) best = i; + } + return best; + } + return 1; +} + +// Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers. +// Uses a fraction of total GPU memory to avoid OOM while leaving room for weights, activations, +// and other ops. No CUDA API is called; total_global_mem is expected from cached device props. +// +// Formula: +// budget = total_global_mem * kFraction +// return clamp(budget, kMin, kMax) +// with kFraction = 0.1 (10%), kMin = 32 MiB, kMax = 2 GiB. +// +// Example results (effective_max_temp after clamp): +// GPU | totalGlobalMem | effective_max_temp +// -----------------|----------------|-------------------- +// A100 80GB | 80 GiB | 2 GiB (capped) +// RTX 5080 16GB | 16 GiB | 1.6 GiB +// RTX 4090 24GB | 24 GiB | 2 GiB (capped) +// RTX 3080 10GB | 10 GiB | 1 GiB +// GTX 1060 6GB | 6 GiB | 614.4 MiB +// GTX 1050 4GB | 4 GiB | 409.6 MiB +// Jetson 2GB | 2 GiB | 204.8 MiB +size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) { + constexpr double kFraction = 0.1; + constexpr size_t kMin = 32ULL * 1024 * 1024; + constexpr size_t kMax = 2ULL * 1024 * 1024 * 1024; + size_t budget = static_cast(static_cast(total_global_mem) * kFraction); + return std::clamp(budget, kMin, kMax); +} + +// Returns how many images to process in parallel per batch chunk for DeformConv. +// Chooses the largest divisor of batch size N that fits in the temp budget and does not +// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder). +// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1, +// so batching is effectively disabled (single-image chunks). +// +// Formulas: +// kernel_size = kH * kW +// output_image_size = out_h * out_w +// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T) +// (temp bytes per image: im2col col buffer + GEMM output buffer per output position) +// max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image)) +// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem) +// return GetGreatestDivisorBelowBound(N, target_parallel_imgs) +template +int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) { + const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem); + const int64_t kernel_size = params.kH * params.kW; + const int64_t output_image_size = params.out_h * params.out_w; + const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); + const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); + const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); + return GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); +} + +} // namespace + +template +Status DeformConv::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); + const auto* mask = context->Input(4); + + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse( + attrs_, + X->Shape(), + W->Shape(), + offset->Shape(), + B ? &B->Shape() : nullptr, + mask ? &mask->Shape() : nullptr, + params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; + + Tensor* Y = context->Output(0, {N, M, out_h, out_w}); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const int n_parallel_imgs = GetNParallelImgs(params, GetDeviceProp().totalGlobalMem); + + const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = (C / group) * kernel_size; + + const int64_t col_stride = static_cast(n_parallel_imgs) * output_image_size; + const int64_t col_buffer_size = (C * kernel_size) * col_stride; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + // Removed col_transposed allocation as we avoid physical transpose. + auto gemm_output_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt((M / group) * col_stride)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ? B->Data() : nullptr; + + cudaStream_t stream = Stream(context); + cublasHandle_t cublas = GetCublasHandle(context); + const cudaDeviceProp& device_prop = GetDeviceProp(); + CudaT alpha = ToCudaType::FromFloat(1.0f); + CudaT beta = ToCudaType::FromFloat(0.0f); + + for (int64_t b = 0; b < N; b += n_parallel_imgs) { + const int cur_parallel = static_cast(std::min(static_cast(n_parallel_imgs), N - b)); + const int64_t cur_out_size = static_cast(cur_parallel) * output_image_size; + + const T* X_block = Xdata + b * (C * input_image_size); + const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr; + + ORT_RETURN_IF_ERROR(DeformConvIm2ColImpl( + stream, + X_block, + offset_block, + mask_block, + col_buffer.get(), + cur_parallel, + C, + H, + W_in, + kH, + kW, + out_h, + out_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + offset_group, + use_mask)); + + // GEMM layout trick: compute Y = W * Col without physical transpose. + // + // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size]. + // cuBLAS is column-major. Key insight: row-major A[M,K] in memory equals column-major A^T[K,M]. + // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose): + // - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T + // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T + // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major + // + // m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size. + // + // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write + // directly into Y_g. Use strided batched for all groups in one call. + // cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. + + const bool gemm_writes_directly = (cur_parallel == 1); + if (gemm_writes_directly) { + // Strided batched: one call for all groups. Strides between batches: + const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 + const int64_t stride_weight = (M / group) * kernel_dim; + const int64_t stride_y = (M / group) * output_image_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(output_image_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_buffer.get()), + narrow(output_image_size), + stride_col, + reinterpret_cast(Wdata), + narrow(kernel_dim), + stride_weight, + &beta, + reinterpret_cast(Ydata + b * M * output_image_size), + narrow(output_image_size), + stride_y, + narrow(group), + device_prop, + UseTF32())); + } else { + // cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group. + for (int64_t g = 0; g < group; ++g) { + const T* W_g = Wdata + g * (M / group) * kernel_dim; + const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; + T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; + + CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(cur_out_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_g), + narrow(cur_out_size), + reinterpret_cast(W_g), + narrow(kernel_dim), + &beta, + reinterpret_cast(gemm_output_buffer.get()), + narrow(cur_out_size), + device_prop, + UseTF32()))); + + ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW( + stream, + gemm_output_buffer.get(), + Y_g, + M, + M / group, + output_image_size, + cur_parallel)); + } + } + } + + if (Bdata != nullptr) { + ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w)); + } + + return Status::OK(); +} + +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 19, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float); +REGISTER_DEFORMCONV_KERNEL_TYPED(double); +REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16); + +// BFloat16 only for opset 22; opset 19-21 do not support BFloat16. +ONNX_OPERATOR_TYPED_KERNEL_EX( + DeformConv, + kOnnxDomain, + 22, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + DeformConv); + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h new file mode 100644 index 0000000000000..fa564641d4b98 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/deform_conv_attributes.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +template +class DeformConv final : public CudaKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : CudaKernel(info), attrs_(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + DeformConvAttributes attrs_; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv); +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu new file mode 100644 index 0000000000000..7b3666fca810b --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -0,0 +1,512 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv: deformable im2col kernel + bilinear interpolation. +// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv spec. + +#include "deform_conv_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/fast_divmod.h" +#include "core/common/float16.h" +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kDeformConvThreadsPerBlock = 256; + +template +struct DeformConvKSize { + static constexpr int value = N; +}; + +// Calculate grid size with a safety limit to prevent overflow. +// Since we use grid-stride loops in kernels, limiting the grid size is safe. +inline int GetGridSize(size_t n, size_t threads_per_block) { + size_t blocks_needed = (n + threads_per_block - 1) / threads_per_block; + return static_cast(std::min(blocks_needed, static_cast(std::numeric_limits::max()))); +} + +// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly. +template +__device__ __inline__ T DeformConvLdg(const T* p) { + return __ldg(p); +} +template <> +__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) { + return BFloat16::FromBits(__ldg(reinterpret_cast(p))); +} + +// Traits for bilinear interpolation math: +// - ComputeT: type used for coordinate/weight math (float for half/BFloat16, T otherwise) +// - Load: load one element and convert to ComputeT +// - ToResult: convert ComputeT result back to T +// - Zero: zero value of T +template +struct DeformConvBilinearTraits { + using ComputeT = T; + + __device__ static __inline__ ComputeT Load(const T* p) { + return __ldg(p); + } + + __device__ static __inline__ T ToResult(ComputeT v) { + return v; + } + + __device__ static __inline__ T Zero() { + return static_cast(0); + } +}; + +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const half* p) { + return __half2float(__ldg(p)); + } + + __device__ static __inline__ half ToResult(ComputeT v) { + return __float2half(v); + } + + __device__ static __inline__ half Zero() { + return __float2half(0.0f); + } +}; + +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const BFloat16* p) { + return static_cast(DeformConvLdg(p)); + } + + __device__ static __inline__ BFloat16 ToResult(ComputeT v) { + return BFloat16(v); + } + + __device__ static __inline__ BFloat16 Zero() { + return BFloat16(0.0f); + } +}; + +// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). +// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure and +// improve occupancy in the hot path. Limitation: (H+1)*W must not exceed INT_MAX; this is +// validated on the host side in DeformConvValidateAndParse to guarantee index math in int +// does not overflow. For half/BFloat16, coordinate and weight math use float via +// DeformConvBilinearTraits to avoid precision loss. We keep floor() results in CoordT and +// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT +// round trips when computing lh/lw/hh/hw. +template +__device__ __inline__ T BilinearInterpolate( + const T* in, + int height, + int width, + typename DeformConvBilinearTraits::ComputeT h, + typename DeformConvBilinearTraits::ComputeT w) { + using Traits = DeformConvBilinearTraits; + using CoordT = typename Traits::ComputeT; + + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return Traits::Zero(); + } + + // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. + CoordT h_floor = _Floor(h); + CoordT w_floor = _Floor(w); + int h_low = static_cast(h_floor); + int w_low = static_cast(w_floor); + int h_high = h_low + 1; + int w_high = w_low + 1; + + CoordT lh = h - h_floor; + CoordT lw = w - w_floor; + CoordT hh = static_cast(1) - lh; + CoordT hw = static_cast(1) - lw; + + // [Optimization 3]: Avoid a second multiply for base_high. + // Original code computed both bases as: + // base_low = h_low * width; + // base_high = h_high * width; + // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and + // save one integer multiply in the hot path: + // base_low = h_low * width; + // base_high = base_low + width; + int base_low = h_low * width; + int base_high = base_low + width; + + CoordT v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast(0); + CoordT v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast(0); + CoordT v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast(0); + CoordT v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast(0); + + CoordT w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return Traits::ToResult(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); +} + +// kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling. +template +__global__ void DeformableIm2ColKernel( + IndexT num_kernels, + const T* __restrict__ input, + const T* __restrict__ offset, + const T* __restrict__ mask, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t channels, + int64_t offset_group, + DivMod out_h_div, + DivMod out_w_div, + DivMod parallel_imgs_div, + DivMod channel_per_offset_grp_div, + bool use_mask, + T* __restrict__ data_col) { + constexpr bool is_fixed = (kH >= 0 && kW >= 0); + const int64_t h_dim = is_fixed ? kH : weight_h; + const int64_t w_dim = is_fixed ? kW : weight_w; + + // Reconstruct dimensions from DivMod objects + const int64_t out_h = out_h_div.d_; + const int64_t out_w = out_w_div.d_; + const int64_t parallel_imgs = parallel_imgs_div.d_; + + const int64_t out_size = out_h * out_w; + // The stride for data_col is (parallel_imgs * out_h * out_w) + const int64_t col_stride = parallel_imgs * out_size; + + using CoordT = typename DeformConvBilinearTraits::ComputeT; + + for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { + IndexT val = index; + IndexT out_x, out_y, out_b, in_c; + + // Fast division/modulo to recover coordinates + out_w_div.divmod(val, val, out_x); + out_h_div.divmod(val, val, out_y); + parallel_imgs_div.divmod(val, in_c, out_b); + + // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case). + IndexT offset_grp = 0; + if (offset_group > 1) { + IndexT dummy; + channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + } + + // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic + // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops. + + // 1. Input pointer base for this batch and channel. + const T* input_ptr = input + static_cast(out_b) * (channels * height * width) + static_cast(in_c) * (height * width); + + // 2. Spatial index in the output feature map. + const int64_t spatial_idx = static_cast(out_y) * out_w + static_cast(out_x); + + // 3. Offset pointer base calculation. + // Layout: (N, offset_groups, 2*KH*KW, OH, OW) + // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx. + const int64_t offset_group_block_size = 2 * h_dim * w_dim * out_size; + const T* offset_ptr_base = offset + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * offset_group_block_size + spatial_idx; + + // 4. Mask pointer base calculation (if used). + // Layout: (N, offset_groups, KH*KW, OH, OW) + const T* mask_ptr_base = nullptr; + if (use_mask) { + const int64_t mask_group_block_size = h_dim * w_dim * out_size; + mask_ptr_base = mask + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * mask_group_block_size + spatial_idx; + } + + // 5. Output pointer base calculation. + // data_col Layout: (C * KH * KW, N * OH * OW) + // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx. + // The starting row for this channel is `in_c * KH * KW`. + const int64_t c_col = static_cast(out_b) * out_size + spatial_idx; + T* data_col_ptr_base = data_col + (static_cast(in_c) * h_dim * w_dim) * col_stride + c_col; + + // 6. Pre-calculate invariant coordinate parts. + // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. + const CoordT base_h_im = static_cast(out_y * stride_h - pad_h); + const CoordT base_w_im = static_cast(out_x * stride_w - pad_w); + + auto process_kernel_point = [&](int64_t i, int64_t j) { + const int64_t kernel_idx = i * w_dim + j; + T mask_val = static_cast(1); + if (use_mask) { + // Access mask using pre-calculated base and stride. + mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); + } + + // Calculate offset pointers relative to the base. + // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. + // Stride between y_offset and x_offset is `out_size`. + const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; + + const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); + + const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; + const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; + + // height/width are validated on host (DeformConvValidateAndParse) so int is safe here. + T val = BilinearInterpolate(input_ptr, + static_cast(height), + static_cast(width), + h_im, + w_im); + + // Match CPU path: always interpolate then apply mask to keep branch-free hot loop. + data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; + }; + + if constexpr (is_fixed) { +#pragma unroll + for (int i = 0; i < kH; ++i) { +#pragma unroll + for (int j = 0; j < kW; ++j) { + process_kernel_point(i, j); + } + } + } else { + for (int64_t i = 0; i < weight_h; ++i) { + for (int64_t j = 0; j < weight_w; ++j) { + process_kernel_point(i, j); + } + } + } + } +} + +// Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW. +template +__global__ void DeformConvAddBiasKernel( + T* Y, + const T* B, + DivMod spatial_div, // For dividing by (H * W) + DivMod channel_div, // For dividing by M (channel count) + int64_t total_elements) { + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += blockDim.x * gridDim.x) { + int64_t val = idx; + int64_t batch_channel_idx, pixel_idx; + + // 1. First decomposition: decompose idx into (batch_channel_idx, pixel_idx) + // Equivalent to: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); + spatial_div.divmod(val, batch_channel_idx, pixel_idx); + + int64_t batch_idx, channel_idx; + + // 2. Second decomposition: decompose batch_channel_idx into (batch_idx, channel_idx) + // Equivalent to: channel_idx = batch_channel_idx % M; + // We only need channel_idx (i.e. m) + channel_div.divmod(batch_channel_idx, batch_idx, channel_idx); + (void)batch_idx; // Only channel_idx is needed + + // channel_idx is what we need (i.e. m) + Y[idx] += DeformConvLdg(B + channel_idx); + } +} + +// Copy GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) into NCHW Y_g. +// src(c, j) with j = b_idx*output_image_size + pos -> dst[b_idx*M*output_image_size + c*output_image_size + pos]. +template +__global__ void CopyGemmOutputRowMajorToNCHWKernel( + const T* __restrict__ src, + T* __restrict__ dst, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { + int64_t pos = idx % output_image_size; + int64_t c = (idx / output_image_size) % M_per_group; + int64_t b_idx = idx / (output_image_size * M_per_group); + int64_t j = b_idx * output_image_size + pos; + // src index for row-major: c * (cur_parallel * output_image_size) + j + dst[b_idx * M * output_image_size + c * output_image_size + pos] = src[c * (cur_parallel * output_image_size) + j]; + } +} + +} // namespace + +template +Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { + int64_t total = N * M * out_h * out_w; + if (total <= 0) return Status::OK(); + + // 1. Prepare divisor + int64_t out_size = out_h * out_w; + + // 2. Create FastDivMod object (note: ensure int64_t version of DivMod is used here) + DivMod spatial_div(out_size); + DivMod channel_div(M); + + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + + // 3. Pass DivMod objects + DeformConvAddBiasKernel<<>>( + Y, + B, + spatial_div, + channel_div, + total); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + if (total <= 0) return Status::OK(); + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + CopyGemmOutputRowMajorToNCHWKernel<<>>( + gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, + const T* offset, + const T* mask, + T* col_buffer, + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask) { + const int64_t num_kernels = static_cast(C) * out_h * out_w * parallel_imgs; + if (num_kernels <= 0) { + return Status::OK(); + } + + const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; + const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) || + (col_numel > static_cast(std::numeric_limits::max())); + + int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); + + auto launch = [&](auto kH_tag, auto kW_tag) { + constexpr int KH = decltype(kH_tag)::value; + constexpr int KW = decltype(kW_tag)::value; + if (use_64bit) { + DeformableIm2ColKernel<<>>( + num_kernels, input, offset, mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(out_h), DivMod(out_w), DivMod(parallel_imgs), + DivMod(C / offset_group), use_mask, col_buffer); + } else { + DeformableIm2ColKernel<<>>( + static_cast(num_kernels), input, offset, mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(static_cast(out_h)), + DivMod(static_cast(out_w)), + DivMod(static_cast(parallel_imgs)), + DivMod(static_cast(C / offset_group)), + use_mask, col_buffer); + } + }; + + if (kH == 1 && kW == 1) { + launch(DeformConvKSize<1>{}, DeformConvKSize<1>{}); + } else if (kH == 3 && kW == 3) { + launch(DeformConvKSize<3>{}, DeformConvKSize<3>{}); + } else if (kH == 5 && kW == 5) { + launch(DeformConvKSize<5>{}, DeformConvKSize<5>{}); + } else { + launch(DeformConvKSize<-1>{}, DeformConvKSize<-1>{}); + } + return CUDA_CALL(cudaGetLastError()); +} + +#define INST_DeformConvIm2ColImpl(T) \ + template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool) + +INST_DeformConvIm2ColImpl(float); +INST_DeformConvIm2ColImpl(double); +INST_DeformConvIm2ColImpl(half); +INST_DeformConvIm2ColImpl(BFloat16); + +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); + +template Status DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); + +// Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. +#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ + template <> \ + Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ + int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ + int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ + int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ + int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ + return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + reinterpret_cast(offset), \ + mask ? reinterpret_cast(mask) : nullptr, \ + reinterpret_cast(col_buffer), \ + parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ + offset_group, use_mask); \ + } \ + template <> \ + Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + const ORT_T* gemm_output, ORT_T* Y_g, \ + int64_t M, int64_t M_per_group, \ + int64_t output_image_size, int64_t cur_parallel) { \ + return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + reinterpret_cast(gemm_output), \ + reinterpret_cast(Y_g), \ + M, M_per_group, output_image_size, cur_parallel); \ + } \ + template <> \ + Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T * Y, const ORT_T* B, \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + reinterpret_cast(B), N, M, out_h, out_w); \ + } + +// BFloat16 is not delegated: ORT's BFloat16 is the same type used in device code (ToCudaType in +// cuda_common.h), so the explicit instantiations above (INST_DeformConvIm2ColImpl(BFloat16), etc.) suffice. +DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h new file mode 100644 index 0000000000000..0c26cb55311bc --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/status.h" + +namespace onnxruntime { +namespace cuda { + +// Adds bias to output: Y[n,m,oh,ow] += B[m]. Y is [N, M, out_h, out_w], B is [M]. +// T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvAddBiasImpl( + cudaStream_t stream, + T* Y, + const T* B, + int64_t N, + int64_t M, + int64_t out_h, + int64_t out_w); + +// Copies GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) to NCHW slice at Y_g. +// T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel); + +// Fills col_buffer with deformable im2col. col_buffer layout: row-major [C*kH*kW, parallel_imgs*out_h*out_w]. +// Called once per batch block; caller does GEMM and bias. T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, // [parallel_imgs, C, H, W] + const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] + const T* mask, // [parallel_imgs, offset_group*kH*kW, out_h, out_w] or nullptr + T* col_buffer, // [C*kH*kW, parallel_imgs*out_h*out_w] + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py new file mode 100644 index 0000000000000..a6c4923cd961e --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Generate expected outputs for DeformConv tests using torchvision.ops.deform_conv2d. +Run with: .venv/bin/python onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +Outputs C++-friendly std::vector initializer lists for pasting into deform_conv_op_test.cc + +Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads +[pad_h, pad_w, pad_h, pad_w] are derived from a single (pad_h, pad_w) pair. Asymmetric +pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated. +""" + +import torch +import torchvision.ops + + +def _pair(x: int | tuple[int, int]) -> tuple[int, int]: + if isinstance(x, int): + return (x, x) + return x + + +def to_cpp_list(t: torch.Tensor, fmt="{:.6f}") -> str: + """Flatten tensor in NCHW order and format as C++ initializer list.""" + t = t.detach().float().contiguous() + return ", ".join(fmt.format(x) + "f" for x in t.flatten().tolist()) + + +def run_case( + name: str, + batch_sz: int, + n_in: int, + n_out: int, + n_weight_grps: int, + n_offset_grps: int, + kernel_h: int, + kernel_w: int, + stride: tuple[int, int] | int, + pad: tuple[int, int] | int, + dilation: tuple[int, int] | int, + in_h: int, + in_w: int, + seed: int = 42, +): + """Build inputs with seed, run deform_conv2d, print C++ snippets.""" + torch.manual_seed(seed) + stride_h, stride_w = _pair(stride) + pad_h, pad_w = _pair(pad) + dil_h, dil_w = _pair(dilation) + + out_h = (in_h + 2 * pad_h - (dil_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (kernel_w - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32) + offset = torch.randn(batch_sz, n_offset_grps * 2 * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32) + mask = torch.randn(batch_sz, n_offset_grps * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32) + weight = torch.randn(n_out, n_in // n_weight_grps, kernel_h, kernel_w, dtype=torch.float32) + bias = torch.randn(n_out, dtype=torch.float32) + + # Standard answer from torchvision + out = torchvision.ops.deform_conv2d( + x, + offset, + weight, + bias=bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dil_h, dil_w), + mask=mask, + ) + + # ONNX pads = [top, left, bottom, right] (symmetric: single pad_h, pad_w expanded) + pads_onnx = [pad_h, pad_w, pad_h, pad_w] + + print(f"// --- {name} (seed={seed}) ---") + print(f"// Shapes: X({batch_sz},{n_in},{in_h},{in_w}) W({n_out},{n_in // n_weight_grps},{kernel_h},{kernel_w})") + print(f"// stride=({stride_h},{stride_w}) pad=({pad_h},{pad_w}) dilation=({dil_h},{dil_w})") + print(f"// out_h={out_h} out_w={out_w}") + print() + print("std::vector X = {" + to_cpp_list(x) + "};") + print("std::vector W = {" + to_cpp_list(weight) + "};") + print("std::vector offset = {" + to_cpp_list(offset) + "};") + print("std::vector B = {" + to_cpp_list(bias) + "};") + print("std::vector mask = {" + to_cpp_list(mask) + "};") + print("std::vector expected_Y = {" + to_cpp_list(out) + "};") + print() + print( + "// Params: kernel_shape={" + f"{kernel_h}, {kernel_w}" + "}, stride={" + f"{stride_h}, {stride_w}" + "}, pads={" + + ", ".join(map(str, pads_onnx)) + + "}, dilations={" + + f"{dil_h}, {dil_w}" + + "}, group=" + + str(n_weight_grps) + + ", offset_group=" + + str(n_offset_grps) + ) + print() + return out + + +def main(): + print("// Generated by deform_conv_expected_gen.py (torchvision.ops.deform_conv2d)") + print() + + # Case 1: Same config as PyTorch TestDeformConv.get_fn_args (small batch for readability) + run_case( + "PyTorch get_fn_args style (batch=1)", + batch_sz=1, + n_in=6, + n_out=2, + n_weight_grps=2, + n_offset_grps=3, + kernel_h=3, + kernel_w=2, + stride=(2, 1), + pad=(1, 0), + dilation=(2, 1), + in_h=5, + in_w=4, + seed=42, + ) + + # Case 2: No mask (mask optional) - same config, then expected with mask=None + torch.manual_seed(42) + n_in, n_out = 6, 2 + n_weight_grps, n_offset_grps = 2, 3 + kH, kW = 3, 2 # noqa: N806 + stride_h, stride_w = 2, 1 + pad_h, pad_w = 1, 0 + dil_h, dil_w = 2, 1 + in_h, in_w = 5, 4 + batch_sz = 1 + out_h = (in_h + 2 * pad_h - (dil_h * (kH - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (kW - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32) + offset = torch.randn(batch_sz, n_offset_grps * 2 * kH * kW, out_h, out_w, dtype=torch.float32) + weight = torch.randn(n_out, n_in // n_weight_grps, kH, kW, dtype=torch.float32) + bias = torch.randn(n_out, dtype=torch.float32) + + out_no_mask = torchvision.ops.deform_conv2d( + x, + offset, + weight, + bias=bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dil_h, dil_w), + mask=None, + ) + print("// --- Same inputs, no mask (expected_Y when mask is omitted) ---") + print("std::vector expected_Y_no_mask = {" + to_cpp_list(out_no_mask) + "};") + print() + + # Case 3: groups=2, offset_group=2, non-zero offset (for GroupsWithNonZeroOffset test) + run_case( + "Groups with non-zero offset (batch=1, 2 groups)", + batch_sz=1, + n_in=4, + n_out=2, + n_weight_grps=2, + n_offset_grps=2, + kernel_h=2, + kernel_w=2, + stride=(1, 1), + pad=(0, 0), + dilation=(1, 1), + in_h=3, + in_w=3, + seed=123, + ) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc new file mode 100644 index 0000000000000..860c0d2f08b18 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -0,0 +1,948 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for DeformConv (CPU and Cuda), aligned with PyTorch Vision deform_conv2d tests. +// Reference: https://github.com/pytorch/vision/blob/main/test/test_ops.py (TestDeformConv) + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/testdata/deform_conv_test_data.inc" +#include "test/unittest_util/conversion.h" + +#if defined(USE_CUDA) +#include "test/common/cuda_op_test_utils.h" +#endif + +namespace onnxruntime { +namespace test { + +namespace { + +// Parameters similar to PyTorch TestDeformConv::get_fn_args (smaller for speed). +struct DeformConvTestParams { + int64_t batch_sz; + int64_t n_in_channels; + int64_t n_out_channels; + int64_t n_weight_grps; + int64_t n_offset_grps; + std::vector kernel_shape; // {kH, kW} + std::vector stride; + std::vector pad; + std::vector dilation; + int64_t in_h; + int64_t in_w; +}; + +// Traits for type-specific DeformConv test behavior. +template +struct DeformConvTestTraits; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return v; } + static std::unordered_set ExcludedProviders() { + return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-5f; } + static constexpr float DefaultAtol() { return 1e-5f; } +}; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return FloatsToMLFloat16s(v); } + static std::unordered_set ExcludedProviders() { + return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-2f; } + static constexpr float DefaultAtol() { return 1e-2f; } +}; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { + return std::vector(v.begin(), v.end()); + } + static std::unordered_set ExcludedProviders() { + return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; + } + static constexpr double DefaultRtol() { return 1e-8; } + static constexpr double DefaultAtol() { return 1e-8; } +}; + +#if defined(USE_CUDA) +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return FloatsToBFloat16s(v); } + static std::unordered_set ExcludedProviders() { + return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-2f; } + static constexpr float DefaultAtol() { return 1e-2f; } +}; +#endif + +template +void RunDeformConvTest(const DeformConvTestParams& params, + const std::vector& X, + const std::vector& W, + const std::vector& offset, + const std::vector& B, + const std::vector* mask, + const std::vector& expected_Y, + int opset = 19, + decltype(DeformConvTestTraits::DefaultRtol()) rtol = DeformConvTestTraits::DefaultRtol(), + decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol(), + bool omit_bias = false) { + const int64_t kH = params.kernel_shape[0]; + const int64_t kW = params.kernel_shape[1]; + // ONNX pads format: [pad_top, pad_left, pad_bottom, pad_right] = [pad[0], pad[1], pad[2], pad[3]] + const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - + params.dilation[0] * (kH - 1) - 1) / + params.stride[0] + + 1; + const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - + params.dilation[1] * (kW - 1) - 1) / + params.stride[1] + + 1; + + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; + const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; + const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; + const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + + auto X_t = DeformConvTestTraits::Convert(X); + auto W_t = DeformConvTestTraits::Convert(W); + auto offset_t = DeformConvTestTraits::Convert(offset); + auto expected_Y_t = DeformConvTestTraits::Convert(expected_Y); + + test.AddInput("X", X_shape, X_t); + test.AddInput("W", W_shape, W_t); + test.AddInput("offset", offset_shape, offset_t); + if (omit_bias) { + test.AddOptionalInputEdge(); + } else { + auto B_t = DeformConvTestTraits::Convert(B); + test.AddInput("B", {params.n_out_channels}, B_t); + } + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask)); + } else { + test.AddOptionalInputEdge(); + } + + const float rtol_f = static_cast(rtol); + const float atol_f = static_cast(atol); + test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol_f, atol_f); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); +} + +// MinimalBilinear test: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +// At (0,0) offset (0.5, 0.5) samples center of [1,2;3,4] -> 2.5. +template +void RunMinimalBilinearTest(int opset = 19, int min_cuda_arch = 0, bool omit_bias = false) { +#if defined(USE_CUDA) + if (min_cuda_arch > 0 && !HasCudaEnvironment(min_cuda_arch)) { + return; + } +#else + (void)min_cuda_arch; +#endif + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + // offset shape [N, 2*kH*kW, out_h, out_w] = [1, 2, 2, 2]: ch0=offset_h, ch1=offset_w (for kernel pt 0) + // Layout: offset[n,c,oh,ow]. Flattened (NCHW): [ch0@00, ch0@01, ch0@10, ch0@11, ch1@00, ch1@01, ch1@10, ch1@11] + // (0,0): (0.5, 0.5)->center of [1,2;3,4]->2.5; (0,1): (0,-1)->(0,0)->1; (1,0): (0,0)->3; (1,1): (0,0)->4 + std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, -1.0f, 0.f, 0.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + if (omit_bias) { + RunDeformConvTest(p, X, W, offset, {} /* B unused */, &mask, expected_Y, opset, + DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), true); + } else { + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, opset, + DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), false); + } +} +} // namespace + +// Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +TEST(DeformConvTest, MinimalBilinear) { + RunMinimalBilinearTest(); +} + +// Optional bias omitted: same as MinimalBilinear but B is not provided; output must match B=0. +TEST(DeformConvTest, OptionalBiasOmitted) { + RunMinimalBilinearTest(19, 0, true); +} + +// Minimal case FP16: Same as MinimalBilinear but in FP16 (CUDA-only). +#if defined(USE_CUDA) +TEST(DeformConvTest, MinimalBilinearFP16) { + RunMinimalBilinearTest(19, 530); +} + +// Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA-only, opset 22). +TEST(DeformConvTest, MinimalBilinearBFloat16) { + RunMinimalBilinearTest(22, 800); +} +#endif // defined(USE_CUDA) + +// Minimal case Double (FP64): Same as MinimalBilinear in double precision. +TEST(DeformConvTest, MinimalBilinearDouble) { + RunMinimalBilinearTest(); +} + +// Forward with mask and bias FP16 (CUDA-only; skip when CUDA not available). +#if defined(USE_CUDA) +TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { + int min_cuda_architecture = 530; // FP16 requires SM 5.3+ + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "DeformConv FP16: CUDA not available, skipping."; + return; + } + + DeformConvTestParams p = {}; + p.batch_sz = 2; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w); + const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2); + const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.5f, -0.5f}; + + const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w); + std::vector expected_Y(y_size); + for (int64_t b = 0; b < p.batch_sz; ++b) { + for (int64_t c = 0; c < p.n_out_channels; ++c) { + float val = (c % 2 == 0) ? 0.58f : -0.42f; + for (int64_t i = 0; i < out_h * out_w; ++i) { + expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val; + } + } + } + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} +#endif // defined(USE_CUDA) + +// With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected. +TEST(DeformConvTest, ForwardWithMaskAndBias) { + DeformConvTestParams p = {}; + p.batch_sz = 2; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w); + const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2); + const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); // zero offset -> regular grid sampling + std::vector mask(mask_size, 1.f); + std::vector B = {0.5f, -0.5f}; + + // With offset=0, mask=1: deform_conv equals grouped conv. Per ONNX, group 0 -> output ch 0, group 1 -> ch 1. + // Uniform X=0.1, W=0.1, 2x2 kernel -> 0.08 + B per channel; Y[:,0,:,:]=0.58, Y[:,1,:,:]=-0.42. + const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w); + std::vector expected_Y(y_size); + for (int64_t b = 0; b < p.batch_sz; ++b) { + for (int64_t c = 0; c < p.n_out_channels; ++c) { + float val = (c % 2 == 0) ? 0.58f : -0.42f; + for (int64_t i = 0; i < out_h * out_w; ++i) { + expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val; + } + } + } + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// No mask (optional): same as above but mask omitted; compare to run with ones mask via tolerance. +TEST(DeformConvTest, ForwardNoMask) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = 1 * 2 * 3 * 3; + const size_t w_size = 2 * 2 * 2 * 2; + const size_t offset_size = 1 * 2 * 2 * 2 * out_h * out_w; + const size_t y_size = 1 * 2 * out_h * out_w; + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector B(2, 0.f); + // No mask => mask=1. Zero offset => same as conv. Y = 4*2*0.1*0.1 = 0.08 per position. + std::vector expected_Y(y_size, 0.08f); + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + const std::vector X_shape = {p.batch_sz, p.n_in_channels, p.in_h, p.in_w}; + const std::vector W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2}; + const std::vector offset_shape = {p.batch_sz, p.n_offset_grps * 2 * 2 * 2, out_h, out_w}; + const std::vector Y_shape = {p.batch_sz, p.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W); + test.AddInput("offset", offset_shape, offset); + test.AddInput("B", {p.n_out_channels}, B); + test.AddOptionalInputEdge(); // no mask + test.AddOutput("Y", Y_shape, expected_Y, false, 1e-4f, 1e-4f); + std::unordered_set excluded = {kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +// Empty batch (N=0): allowed, same as Conv/ConvTranspose/Pool — output shape [0, oC, oH, oW]. +TEST(DeformConvTest, EmptyBatch) { + DeformConvTestParams p = {}; + p.batch_sz = 0; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X; + std::vector W = std::vector(2 * 2 * 2 * 2, 0.1f); + std::vector offset; + std::vector B(2, 0.f); + std::vector expected_Y; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + const std::vector X_shape = {0, p.n_in_channels, p.in_h, p.in_w}; + const std::vector W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2}; + const std::vector offset_shape = {0, p.n_offset_grps * 2 * 2 * 2, out_h, out_w}; + const std::vector Y_shape = {0, p.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W); + test.AddInput("offset", offset_shape, offset); + test.AddInput("B", {p.n_out_channels}, B); + test.AddOptionalInputEdge(); + test.AddOutput("Y", Y_shape, expected_Y); + std::unordered_set excluded = {kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +// Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes). +TEST(DeformConvTest, WrongOffsetShape) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 2 * 3 * 3, 0.1f); + std::vector W(2 * 2 * 2 * 2, 0.1f); + std::vector wrong_offset(1 * 2 * out_h * out_w); // wrong: only 2 channels instead of 8 + std::vector B(2, 0.f); + std::vector expected_Y(1 * 2 * out_h * out_w, 0.f); + + const std::vector offset_shape_wrong = {1, 2, out_h, out_w}; + const std::vector Y_shape_wrong = {1, 2, out_h, out_w}; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + test.AddInput("X", {1, 2, 3, 3}, X); + test.AddInput("W", {2, 2, 2, 2}, W); + test.AddInput("offset", offset_shape_wrong, wrong_offset); // invalid channels + test.AddInput("B", {2}, B); + test.AddOptionalInputEdge(); + test.AddOutput("Y", Y_shape_wrong, expected_Y); + std::unordered_set excluded = {kTensorrtExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectFailure, "Offset channel count must be offset_group * 2 * kH * kW", excluded); +} + +// Wrong mask channel count -> expect failure. +TEST(DeformConvTest, WrongMaskShape) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 2 * 3 * 3, 0.1f); + std::vector W(2 * 2 * 2 * 2, 0.1f); + const size_t offset_size = static_cast( + p.batch_sz * p.n_offset_grps * 2 * p.kernel_shape[0] * p.kernel_shape[1] * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector B(2, 0.f); + std::vector wrong_mask(1 * 2 * out_h * out_w); // wrong: 2 instead of 4 + std::vector expected_Y(1 * 2 * out_h * out_w, 0.f); + + const std::vector mask_shape_wrong = {1, 2, out_h, out_w}; + const std::vector Y_shape_mask = {1, 2, out_h, out_w}; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + test.AddInput("X", {1, 2, 3, 3}, X); + test.AddInput("W", {2, 2, 2, 2}, W); + test.AddInput("offset", {1, 8, out_h, out_w}, offset); + test.AddInput("B", {2}, B); + test.AddInput("mask", mask_shape_wrong, wrong_mask); + test.AddOutput("Y", Y_shape_mask, expected_Y); + std::unordered_set excluded = {kTensorrtExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectFailure, "Mask channel count", excluded); +} + +// Opset 22 (same behavior, different opset). +TEST(DeformConvTest, Opset22) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, 0.f, 0.f, 0.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 2.f, 3.f, 4.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); +} + +// Non-square kernel (kH != kW): 2x3 kernel, zero offset -> same as standard conv. +TEST(DeformConvTest, NonSquareKernel) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 3}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 4; + p.in_w = 5; + // ONNX output size: out_h = (4 - 1*(2-1) - 1)/1 + 1 = 3, out_w = (5 - 1*(3-1) - 1)/1 + 1 = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t x_size = static_cast(1 * 1 * 4 * 5); + const size_t w_size = static_cast(1 * 1 * 2 * 3); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 3 * out_h * out_w); // n_offset_grps * 2 * kH * kW * out_h * out_w + const size_t mask_size = static_cast(1 * 1 * 2 * 3 * out_h * out_w); // n_offset_grps * kH * kW * out_h * out_w + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // With offset=0, mask=1: each output = 6 * 0.1 * 0.1 = 0.06 (9 positions) + std::vector expected_Y(static_cast(out_h * out_w), 0.06f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Asymmetric stride (stride_h != stride_w): stride=(2,1), zero offset. +TEST(DeformConvTest, AsymmetricStride) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {2, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 5; + p.in_w = 4; + // out_h = (5 - 1*(2-1) - 1) / 2 + 1 = 2, out_w = (4 - 1*(2-1) - 1) / 1 + 1 = 3 + const int64_t out_h = 2; + const int64_t out_w = 3; + + const size_t x_size = static_cast(1 * 1 * 5 * 4); + const size_t w_size = static_cast(1 * 1 * 2 * 2); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// groups > 0 and non-zero offset; expected from deform_conv_expected_gen.py (seed=123). +TEST(DeformConvTest, GroupsWithNonZeroOffset) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + + std::vector X = {0.296112f, 0.516562f, 0.251671f, 0.688557f, 0.073972f, 0.866522f, 0.136580f, 0.102479f, 0.184056f, 0.726447f, 0.315254f, 0.687107f, 0.075635f, 0.196638f, 0.316412f, 0.401740f, 0.118568f, 0.827395f, 0.382084f, 0.660494f, 0.853572f, 0.593153f, 0.636725f, 0.982629f, 0.274495f, 0.658376f, 0.277542f, 0.857325f, 0.899328f, 0.039014f, 0.926823f, 0.738757f, 0.717884f, 0.705837f, 0.915650f, 0.433980f}; + std::vector W = {-1.182045f, -0.287745f, -0.604301f, 0.600237f, -1.420473f, -0.223828f, 0.430555f, -0.898857f, -0.017858f, 0.426403f, -0.765741f, -0.054514f, -0.732053f, 1.234742f, 1.186221f, -0.220099f}; + std::vector offset = {-0.388483f, -0.934346f, -0.499144f, -1.086653f, 0.962421f, 0.249208f, -0.484502f, -2.092915f, 0.098284f, -0.093507f, 0.266215f, -0.585035f, -0.343038f, -0.682148f, -0.988689f, -1.701830f, -1.220290f, 1.313853f, 1.053300f, 0.138805f, -0.204445f, -2.268529f, -0.913328f, -0.420363f, -0.659559f, -0.797928f, 0.183831f, 0.229347f, 0.617743f, -0.287578f, 0.821824f, 0.151178f, -0.044382f, 1.623557f, -2.322871f, 1.087831f, -0.063545f, -0.448641f, -1.278470f, -1.144004f, -0.152640f, 0.116741f, 0.440260f, -1.446546f, -0.558082f, -0.051696f, -0.908273f, 0.350683f, -0.394809f, 0.489227f, -0.216815f, -1.747165f, 1.722842f, 0.773806f, 0.404630f, -1.646126f, -0.595084f, -0.711218f, 0.622965f, -1.372881f, -0.128065f, -1.283835f, -0.290120f, 1.276741f}; + std::vector B = {0.983955f, 0.204512f}; + std::vector mask = {-0.031861f, -0.478956f, 0.766809f, 0.027468f, 0.047470f, -0.923866f, -1.060737f, -2.324446f, -2.062818f, 0.006375f, -0.989555f, 0.701609f, -0.982238f, 0.277031f, 0.645495f, -0.895681f, 0.492753f, -0.014078f, -0.274663f, -0.764091f, -0.587157f, 1.195165f, -1.209575f, -0.556008f, -0.077105f, 1.277377f, -1.459629f, -2.159528f, -0.706709f, -0.922245f, 3.895372f, -0.602697f}; + std::vector expected_Y = {0.971546f, 1.139858f, 0.452817f, 1.863882f, -0.565266f, 1.423187f, -2.462833f, -0.104923f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// Sampling out of bounds: offset pushes sampling to (-5,-5), BilinearInterpolate returns 0. +TEST(DeformConvTest, OutOfBoundsSampling) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + // out_h=out_w=2 (2x2 output), offset shape [1, 2, 2, 2] = 8 values. All (-5,-5) -> OOB -> 0 + std::vector offset = {-5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {0.f, 0.f, 0.f, 0.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Dilation > 1: 2x2 kernel with dilation (2,2), zero offset -> 4 sample points with stride 2. +TEST(DeformConvTest, DilationGt1) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {2, 2}; + p.in_h = 5; + p.in_w = 5; + // out_h = (5 - 2*(2-1) - 1)/1 + 1 = 3, out_w = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t x_size = 25; + const size_t w_size = 4; + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Each output: 4 samples at (0,0),(0,2),(2,0),(2,2) -> 4 * 0.1 * 0.1 = 0.04 + std::vector expected_Y(static_cast(out_h * out_w), 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Decoupled groups: group=2, offset_group=1 (one offset map shared by all input channels). +TEST(DeformConvTest, DecoupledGroups) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(1 * 4 * 3 * 3); + const size_t w_size = static_cast(2 * 2 * 2 * 2); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f, 0.f}; + // Zero offset -> grouped conv. Per output ch: 2 in_ch * 4 kernel * 0.01 = 0.08 + std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.08f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Asymmetric padding: pads [top=1, left=0, bottom=0, right=1]; output 3x3, some positions have OOB samples. +TEST(DeformConvTest, AsymmetricPadding) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {1, 0, 0, 1}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + // out_h = (3+1+0-1*(2-1)-1)/1+1 = 3, out_w = (3+0+1-1-1)/1+1 = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Row 0: (0,0),(0,1) 2 valid -> 0.02; (0,2) only (0,2) in, (0,3) OOB -> 1 valid -> 0.01. Row 1/2: as before. + std::vector expected_Y = {0.02f, 0.02f, 0.01f, 0.04f, 0.04f, 0.02f, 0.04f, 0.04f, 0.02f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Tiny offset (near zero): offset (1e-6, 1e-6), sample ~(0,0) -> bilinear ≈ X[0,0]. Use 1x1 input for 1 output. +TEST(DeformConvTest, TinyOffset) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 1; + p.in_w = 1; + + std::vector X = {1.f}; + std::vector W = {1.f}; + std::vector offset = {1e-6f, 1e-6f}; + std::vector B = {0.f}; + std::vector mask = {1.f}; + std::vector expected_Y = {1.f}; // bilinear at (1e-6, 1e-6) ≈ 1 + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// Offset (0.5, 0.5) at each kernel point: sampling at (i+0.5, j+0.5) -> (0.5,0.5),(0.5,1.5),(1.5,0.5),(1.5,1.5). +// Only (0.5,0.5) is fully in-bounds for 2x2 input; others hit boundary (OOB gives 0). Result = 1.6875. +TEST(DeformConvTest, OffsetAtPixelCenters) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {0.25f, 0.25f, 0.25f, 0.25f}; + std::vector offset = { + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {1.6875f}; // op output: one center sample 2.5 + boundary samples + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Large batch (N=64) to trigger CUDA ComputeInternal chunking loop (b += n_parallel_imgs). +TEST(DeformConvTest, LargeBatchSize) { + const int64_t N = 64; + DeformConvTestParams p = {}; + p.batch_sz = N; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(N * 1 * 3 * 3); + const size_t offset_size = static_cast(N * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(N * 1 * 2 * 2 * out_h * out_w); + const size_t y_size = static_cast(N * 1 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(y_size, 0.04f); // 4 * 0.1 * 0.1 per position + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// group=1, offset_group=2: weights not grouped, offset/mask grouped. +TEST(DeformConvTest, Group1OffsetGroup2) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; // C must be divisible by offset_group + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(1 * 4 * 3 * 3); + const size_t w_size = static_cast(2 * 4 * 2 * 2); + const size_t offset_size = static_cast(1 * 2 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 2 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f, 0.f}; + // group=1: full conv. Each output: 4 in_ch * 4 kernel = 16 * 0.01 = 0.16 per channel, 2 out ch -> 0.16 each + std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.16f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Mask with zeros: exercises CUDA early-exit when mask_val == 0. +TEST(DeformConvTest, MaskWithZeros) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + std::vector offset(offset_size, 0.f); + // mask: (1, 4, 2, 2). Set all to 0 -> output should be 0. + std::vector mask(static_cast(1 * 1 * 2 * 2 * out_h * out_w), 0.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Extreme aspect ratio (1x100): thin horizontal strip to verify coordinate indexing. +TEST(DeformConvTest, ExtremeAspectRatio) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 3}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 1; + p.in_w = 100; + // out_h = 1, out_w = (100 - 1*(3-1) - 1)/1 + 1 = 98 + const int64_t out_h = 1; + const int64_t out_w = 98; + + std::vector X(100, 0.1f); + std::vector W(1 * 1 * 1 * 3, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 1 * 3 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 1 * 3 * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Each output: 3 * 0.1 * 0.1 = 0.03 + std::vector expected_Y(static_cast(out_h * out_w), 0.03f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// ONNX model data test: deform_conv_test_gen.py builds the ONNX model (via onnx.helper) +// and generates fixed inputs from torchvision (seed=123). This test is a model-loading/ +// integration smoke test that uses ORT-generated outputs from deform_conv_test.onnx as the reference. +TEST(DeformConvTest, OnnxModelTest) { + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", std::vector{2, 2}); + test.AddAttribute("strides", std::vector{1, 1}); + test.AddAttribute("pads", std::vector{0, 0, 0, 0}); + test.AddAttribute("dilations", std::vector{1, 1}); + test.AddAttribute("group", static_cast(2)); + test.AddAttribute("offset_group", static_cast(2)); + + test.AddInput("X", {1, 4, 3, 3}, kDeformConvOnnxTest_X); + test.AddInput("W", {2, 2, 2, 2}, kDeformConvOnnxTest_W); + test.AddInput("offset", {1, 16, 2, 2}, kDeformConvOnnxTest_offset); + test.AddInput("B", {2}, kDeformConvOnnxTest_B); + test.AddInput("mask", {1, 8, 2, 2}, kDeformConvOnnxTest_mask); + test.AddReferenceOutputs("testdata/deform_conv_test.onnx", 1e-4f); + + std::unordered_set excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/deform_conv_test.onnx b/onnxruntime/test/testdata/deform_conv_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b643014e44acbe4ee5d557f5fdf1b7eeca1951f0 GIT binary patch literal 353 zcmZ9IL2kk@5Jed`CGJ-W!L+DSrK;*OM_^@9j(|kT2BnHaBOs3Aw0G%_Q*j>b(m+7g z-^~1Zf5!ZNyl40&&8|Nz#t| zD8TKi(%sv5lL zo#SW9)bX=jRgCb!NrYgWtURk5C)b>}n#>kYieH=aS`IhvFn_MNZx0s$w`|8`@yq`= VT;}m+;M3+Uu4t#ciHA-&JOF_iJKX>P literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/deform_conv_test_data.inc b/onnxruntime/test/testdata/deform_conv_test_data.inc new file mode 100644 index 0000000000000..206d8517dd3e3 --- /dev/null +++ b/onnxruntime/test/testdata/deform_conv_test_data.inc @@ -0,0 +1,10 @@ +// Auto-generated by deform_conv_test_gen.py - do not edit + +#include + +static const std::vector kDeformConvOnnxTest_X = {0.296111941f, 0.516562283f, 0.251670718f, 0.68855679f, 0.0739724636f, 0.866521955f, 0.136579871f, 0.102479041f, 0.184056461f, 0.726446748f, 0.315253913f, 0.687106669f, 0.075635314f, 0.196638167f, 0.316411972f, 0.401740134f, 0.118568301f, 0.82739538f, 0.382084429f, 0.660493851f, 0.853571773f, 0.593153f, 0.636725366f, 0.982629359f, 0.274495304f, 0.658375621f, 0.277541935f, 0.857324839f, 0.899328232f, 0.0390138626f, 0.926822901f, 0.738757193f, 0.717883527f, 0.705837429f, 0.915649533f, 0.433980227f}; +static const std::vector kDeformConvOnnxTest_W = {-1.18204546f, -0.287744999f, -0.604300678f, 0.600236714f, -1.42047262f, -0.223827749f, 0.430554837f, -0.89885664f, -0.0178579595f, 0.426403075f, -0.765740693f, -0.0545141846f, -0.732052684f, 1.23474216f, 1.18622088f, -0.220098898f}; +static const std::vector kDeformConvOnnxTest_offset = {-0.388483077f, -0.934345901f, -0.499144107f, -1.08665264f, 0.962421f, 0.249208495f, -0.484502077f, -2.09291434f, 0.0982837752f, -0.0935074314f, 0.266214728f, -0.585035503f, -0.343037993f, -0.682147384f, -0.988689423f, -1.70183039f, -1.2202903f, 1.31385386f, 1.05329967f, 0.138805181f, -0.204444751f, -2.26852894f, -0.913327932f, -0.420362711f, -0.659559608f, -0.797927678f, 0.18383126f, 0.229347408f, 0.617742658f, -0.287577927f, 0.821824312f, 0.151177585f, -0.0443819836f, 1.62355745f, -2.32287097f, 1.08783054f, -0.0635453761f, -0.448640704f, -1.27846932f, -1.14400387f, -0.152640373f, 0.116741188f, 0.44026047f, -1.44654655f, -0.558081627f, -0.0516963229f, -0.90827328f, 0.350683212f, -0.394808769f, 0.489227712f, -0.216814891f, -1.74716449f, 1.72284174f, 0.773806036f, 0.404629797f, -1.64612663f, -0.59508425f, -0.711217523f, 0.622964859f, -1.37288189f, -0.128064156f, -1.28383458f, -0.290120065f, 1.27674019f}; +static const std::vector kDeformConvOnnxTest_B = {0.983955026f, 0.204511523f}; +static const std::vector kDeformConvOnnxTest_mask = {-0.0318612382f, -0.478955716f, 0.766808629f, 0.0274681915f, 0.0474699028f, -0.92386651f, -1.06073678f, -2.32444572f, -2.06281757f, 0.00637452863f, -0.989554703f, 0.701609194f, -0.982237995f, 0.277030349f, 0.645495057f, -0.895680785f, 0.492752999f, -0.0140781598f, -0.274662733f, -0.764091492f, -0.58715719f, 1.1951654f, -1.20957518f, -0.556007624f, -0.0771045536f, 1.27737665f, -1.45962942f, -2.15952778f, -0.70670861f, -0.92224431f, 3.89537215f, -0.602696717f}; +static const std::vector kDeformConvOnnxTest_expected_Y = {0.971546292f, 1.1398586f, 0.452816963f, 1.86388242f, -0.565265715f, 1.42318761f, -2.46283293f, -0.104923099f}; diff --git a/onnxruntime/test/testdata/deform_conv_test_data.npz b/onnxruntime/test/testdata/deform_conv_test_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..68639f753501dc778a4a237855d3266f25edd252 GIT binary patch literal 2092 zcmbtVe@qj16u;8qNJkpBR75sMH&}|ZfVqaMW;|J^MJ>TzwhLI;p#rQQkp(uuAKG49Q8K;P5DjYhdPS9mL zHxrgRO(r{f$?t>Hba;O<)YyB7!hbn@ad-!e-&_I~Dj1F*Erw4kt4U&2D9kKOkoC`}qB$DYulHa^vH^{sX2 z<9ibPhw7oxXk=W|9@iv}3293?_E^kjt0`Z|-R7i_U_Hv6qfkWjeEtheUb^A-B}Nxv z!D}BpWGqXnP-85B!>isSt$q)fEz{ntX!8|%(1wl#eFeF7aWQ^=x}AwW6^3kJ3D#Zz z0gED=aF?nT$_L6}`UM&JDt|-Nr5VrV)ZjPzXBmc&CuwwRt((HGeQQdM*4d znTYAN#L{JlvQTsB2%hVavubY%n>EqD?(O6fY-@^w`_^K-Gjx=^1TVrioe6{DXUXDL z^MwX`HDhax!+Tr&G3=iPcw_JcmTi;af4_X}ke4nK zABe7n7k7l03dqSg zKNPR{5@eq&WA981ISe|H*GkIuP~*=nUcH^U@MeAx%^h*>F8lyzxh1^1x@_V$nH%Ts zF8uK0cjQ&dxd~zu5ijboQRwc+Hy+ [pad_h, pad_w, pad_h, pad_w] + pads = [PAD_H, PAD_W, PAD_H, PAD_W] + + node = helper.make_node( + "DeformConv", + inputs=["X", "W", "offset", "B", "mask"], + outputs=["Y"], + kernel_shape=[KH, KW], + strides=[STRIDE_H, STRIDE_W], + pads=pads, + dilations=[DIL_H, DIL_W], + group=N_WEIGHT_GRPS, + offset_group=N_OFFSET_GRPS, + ) + + graph = helper.make_graph( + [node], + "DeformConvTest", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [BATCH, N_IN, IN_H, IN_W]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [N_OUT, N_IN // N_WEIGHT_GRPS, KH, KW]), + helper.make_tensor_value_info( + "offset", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W] + ), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [N_OUT]), + helper.make_tensor_value_info("mask", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [BATCH, N_OUT, OUT_H, OUT_W])], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)]) + checker.check_model(model) + return model + + +def _to_cpp_array(name: str, arr: np.ndarray) -> str: + """Format numpy array as C++ initializer list.""" + flat = arr.flatten().tolist() + vals = ", ".join(f"{x:.9g}f" for x in flat) + return f"static const std::vector {name} = {{{vals}}};" + + +def _write_cpp_inc(data: dict, inc_path: Path) -> None: + """Write C++ include file with test data.""" + lines = [ + "// Auto-generated by deform_conv_test_gen.py - do not edit", + "", + "#include ", + "", + _to_cpp_array("kDeformConvOnnxTest_X", data["X"]), + _to_cpp_array("kDeformConvOnnxTest_W", data["W"]), + _to_cpp_array("kDeformConvOnnxTest_offset", data["offset"]), + _to_cpp_array("kDeformConvOnnxTest_B", data["B"]), + _to_cpp_array("kDeformConvOnnxTest_mask", data["mask"]), + _to_cpp_array("kDeformConvOnnxTest_expected_Y", data["expected_Y"]), + "", + ] + inc_path.write_text("\n".join(lines), encoding="utf-8") + + +def main(): + # Output to testdata/ root (same as layernorm.onnx, attention_past_state.onnx, etc.) + script_dir = Path(__file__).resolve().parent + assert script_dir.name == "nn", "Script must live in testdata/nn/" + testdata_root = script_dir.parent + model_path = testdata_root / "deform_conv_test.onnx" + data_path = testdata_root / "deform_conv_test_data.npz" + inc_path = testdata_root / "deform_conv_test_data.inc" + + print("Generating reference via torchvision.ops.deform_conv2d...") + data = _generate_reference() + + print("Building ONNX model...") + model = _build_onnx_model() + save(model, str(model_path)) + print(f" Saved {model_path}") + + np.savez(str(data_path), **data) + print(f" Saved {data_path}") + + _write_cpp_inc(data, inc_path) + print(f" Saved {inc_path}") + + # Validate with onnxruntime if available + if ort is not None: + print("Validating with ONNX Runtime...") + sess = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + ort_out = sess.run( + ["Y"], + { + "X": data["X"], + "W": data["W"], + "offset": data["offset"], + "B": data["B"], + "mask": data["mask"], + }, + )[0] + + rtol, atol = 1e-4, 1e-4 + if np.allclose(ort_out, data["expected_Y"], rtol=rtol, atol=atol): + print(" PASS: ORT output matches reference.") + else: + diff = np.abs(ort_out.astype(np.float64) - data["expected_Y"].astype(np.float64)) + print(f" FAIL: max |diff|={diff.max()}, mean={diff.mean()}") + else: + print(" (onnxruntime not installed; skip validation)") + + +if __name__ == "__main__": + main() From 3bb032e8151fbc93a154a96aa8616b390b066dba Mon Sep 17 00:00:00 2001 From: Sagar Date: Mon, 23 Mar 2026 23:50:43 +0530 Subject: [PATCH 02/17] Fix non-ASCII Unicode model path crash across session and provider code (#27724) ### Description On Windows, `std::filesystem::path::string()` converts the internal UTF-16 representation to a narrow string using the system's active ANSI code page. When the path contains characters outside that code page (Japanese, Chinese, Korean, etc.), this throws std::system_error with 'No mapping for the Unicode character exists in the target multi-byte code page.' This affected both the core session telemetry logging (causing `InferenceSession::Initialize()` to fail) and execution provider code (OpenVINO, TensorRT, TensorRT RTX, QNN, MIGraphX) where model paths are converted for EPContext attributes and profiling. ### Motivation and Context Fix: Replace `.filename().string()` with `PathToUTF8String(.filename().native())` which uses `WideCharToMultiByte(CP_UTF8, ...)` and handles all Unicode characters correctly. This pattern is already used elsewhere in the codebase for path-to-string conversions. Note: Two remaining instances in Linux-only code (`cann_utils.cc`, `device_discovery.cc`) are left as-is since `.string()` is safe on Linux where paths are already narrow strings. Fixes microsoft/WindowsAppSDK#6173 --------- Co-authored-by: Sagar Bhure Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../providers/migraphx/migraphx_execution_provider_utils.h | 2 +- .../core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc | 4 ++-- onnxruntime/core/providers/openvino/backend_manager.cc | 2 +- .../core/providers/qnn/builder/onnx_ctx_model_helper.cc | 2 +- .../core/providers/qnn/builder/qnn_profile_serializer.cc | 2 +- onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc | 2 +- onnxruntime/core/session/environment.cc | 2 +- onnxruntime/core/session/inference_session.cc | 4 ++-- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index cce90f3ef82be..e20cc9140916a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -268,7 +268,7 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { const fs::path path{main_graph.ModelPath()}; if (path.has_filename()) { - const auto model_name = path.filename().string(); + const auto model_name = PathToUTF8String(path.filename().native()); LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; // Ensure enough characters are hashed in case model names are too short diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 7baac6aa1f6d0..0b137c4674b00 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -166,7 +166,7 @@ Status CreateCtxNode(const GraphViewer& graph_viewer, } attr_ep_cache_context->set_s(engine_data_str); } else { - std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string(); + std::string engine_cache_filename = PathToUTF8String(std::filesystem::path(engine_cache_path).filename().native()); attr_ep_cache_context->set_s(engine_cache_filename); std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out); if (engine_cache_file.is_open()) { @@ -188,7 +188,7 @@ Status CreateCtxNode(const GraphViewer& graph_viewer, attr_onnx_filename->set_name(ONNX_MODEL_FILENAME); attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string()); + attr_onnx_filename->set_s(PathToUTF8String(std::filesystem::path(onnx_model_path).filename().native())); attr_sdk_version->set_name(SDK_VERSION); attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 892bdec7abe83..c033b0b10a786 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -185,7 +185,7 @@ void BackendManager::TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVi model_blob_str = std::move(ss).str(); } } else { // External blob - model_blob_str = shared_context_.GetBinPath().filename().string(); + model_blob_str = PathToUTF8String(shared_context_.GetBinPath().filename().native()); } auto status = ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 0e49c0f897bea..68aa9a157f4a2 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -269,7 +269,7 @@ Status CreateEPContextNodes(Model* model, } context_bin_path = context_bin_path + ToPathString("_qnn.bin"); - context_cache_name = std::filesystem::path(context_bin_path).filename().string(); + context_cache_name = PathToUTF8String(std::filesystem::path(context_bin_path).filename().native()); // If generate ctx.onnx with share_ep_context enabled, all ctx.onnx should point to the same ctx.bin if (share_ep_contexts) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc index fb76f2110cbc8..87340e5b3ebeb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc @@ -237,7 +237,7 @@ Serializer::Serializer(const ProfilingInfo& profiling_info, tracelogging_provider_ep_enabled_(tracelogging_provider_ep_enabled) { #ifdef QNN_SYSTEM_PROFILE_API_ENABLED std::filesystem::path output_fs_filepath(profiling_info.csv_output_filepath); - qnn_log_filename_ = output_fs_filepath.filename().string(); + qnn_log_filename_ = PathToUTF8String(output_fs_filepath.filename().native()); // Remove extension (assumed to be ".csv") then add "_qnn.log" size_t extension_start_idx = qnn_log_filename_.rfind("."); qnn_log_filename_ = qnn_log_filename_.substr(0, extension_start_idx); diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 2a54bfea86e91..091b110d8c746 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -120,7 +120,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, attr_2->set_s(compute_capability); attr_3->set_name(ONNX_MODEL_FILENAME); attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); + attr_3->set_s(PathToUTF8String(std::filesystem::path(onnx_model_path).filename().native())); attr_4->set_name(SOURCE); attr_4->set_type(onnx::AttributeProto_AttributeType_STRING); attr_4->set_s(kTensorrtExecutionProvider); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 7cd02e5413407..f37c685cf2f28 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -618,7 +618,7 @@ Status Environment::CreateAndRegisterInternalEps() { Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) { std::lock_guard lock{mutex_}; - std::string lib_file_name = std::filesystem::path(lib_path).filename().string(); + std::string lib_file_name = PathToUTF8String(std::filesystem::path(lib_path).filename().native()); Env::Default().GetTelemetryProvider().LogRegisterEpLibraryWithLibPath(registration_name, lib_file_name); std::vector internal_factories = {}; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 08b58f3de1a11..b873c95b496bb 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2612,7 +2612,7 @@ common::Status InferenceSession::Initialize() { // and log telemetry std::filesystem::path model_path = graph.ModelPath(); - std::string model_file_name = model_path.filename().string(); + std::string model_file_name = PathToUTF8String(model_path.filename().native()); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); env.GetTelemetryProvider().LogSessionCreation( session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), @@ -4096,7 +4096,7 @@ void InferenceSession::LogAllSessions() { if (nullptr != model) { onnxruntime::Graph& graph = model->MainGraph(); std::filesystem::path model_path = graph.ModelPath(); - std::string model_file_name = model_path.filename().string(); + std::string model_file_name = PathToUTF8String(model_path.filename().native()); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); std::string model_weight_type = session->GetWeightDataType(); std::string model_graph_hash = session->GetGraphHash(); From 45b5900d107dd45e48285a3d77ce275e8f6528ec Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 11:22:29 -0700 Subject: [PATCH 03/17] [CUDA] RoiAlign for opset versions 16 and 22 (#27646) Support RoiAlign for opset versions 16 and 22 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/OperatorKernels.md | 4 +- .../providers/cpu/object_detection/roialign.h | 4 + .../providers/cuda/cuda_execution_provider.cc | 22 ++- .../cuda/object_detection/roialign.cc | 45 +++++- .../cuda/object_detection/roialign_impl.cu | 112 ++++++++------- .../cpu/object_detection/roialign_test.cc | 134 ++++++++++++++++++ 6 files changed, 262 insertions(+), 59 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 25829d08206cf..7dfa76193a5b0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -906,7 +906,9 @@ Do not modify directly.* |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|22+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h index bb97de158369b..4ce4825e1d78c 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.h +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h @@ -129,6 +129,10 @@ class RoiAlignBase { std::string coordinate_transformation_mode; if (info.template GetAttr("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) { half_pixel_ = coordinate_transformation_mode == "half_pixel"; + } else { + // For opset 16+, the default is "half_pixel" per ONNX spec. + // For opset 10 (which has no coordinate_transformation_mode attribute), false is correct. + half_pixel_ = info.node().SinceVersion() >= 16; } if (mode_ == RoiAlignMode::max && sampling_ratio_ != 1) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4f36789191bb2..0808a7110f242 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -944,8 +944,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu); @@ -1615,6 +1618,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, GRU); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); // Opset 23. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -2070,8 +2077,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2742,6 +2752,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc index 71fb066c2898f..5d876ae5a2cc9 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc @@ -7,11 +7,37 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ RoiAlign, \ kOnnxDomain, \ 10, \ + 15, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 16, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_TYPED_ROIALIGN_OP_22(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -67,13 +93,22 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ +#define SPECIALIZED_COMPUTE(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ADD_TYPED_ROIALIGN_OP_22(T) \ template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double) -// SPECIALIZED_COMPUTE(MLFloat16) +// MLFloat16 is available for RoiAlign op from version 16 (not version 10): +ADD_VERSIONED_TYPED_ROIALIGN_OP_16(MLFloat16) +ADD_TYPED_ROIALIGN_OP_22(MLFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; + +// BFloat16 is available for RoiAlign op from version 22: +ADD_TYPED_ROIALIGN_OP_22(BFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; } // namespace cuda }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu index 7acfd9d075461..87f4aba8e45b2 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu @@ -17,64 +17,72 @@ #include "roialign_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/accumulation_type.h" namespace onnxruntime { namespace cuda { template -__device__ T bilinear_interpolate( +__device__ AccumulationType_t bilinear_interpolate( const T* bottom_data, const int height, const int width, - T y, - T x, + AccumulationType_t y, + AccumulationType_t x, const bool is_mode_avg, const int index /* index for debug only*/) { + using TAcc = AccumulationType_t; + // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { + if (y < static_cast(-1.0f) || y > static_cast(height) || + x < static_cast(-1.0f) || x > static_cast(width)) { // empty - return 0; + return static_cast(0.0f); } - if (y <= 0) { - y = 0; + if (y <= static_cast(0.0f)) { + y = static_cast(0.0f); } - if (x <= 0) { - x = 0; + if (x <= static_cast(0.0f)) { + x = static_cast(0.0f); } - int y_low = (int)y; - int x_low = (int)x; + int y_low = static_cast(y); + int x_low = static_cast(x); int y_high; int x_high; if (y_low >= height - 1) { y_high = y_low = height - 1; - y = (T)y_low; + y = static_cast(y_low); } else { y_high = y_low + 1; } if (x_low >= width - 1) { x_high = x_low = width - 1; - x = (T)x_low; + x = static_cast(x_low); } else { x_high = x_low + 1; } - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; + TAcc ly = y - static_cast(y_low); + TAcc lx = x - static_cast(x_low); + TAcc hy = static_cast(1.0f) - ly; + TAcc hx = static_cast(1.0f) - lx; // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + TAcc v1 = static_cast(bottom_data[y_low * width + x_low]); + TAcc v2 = static_cast(bottom_data[y_low * width + x_high]); + TAcc v3 = static_cast(bottom_data[y_high * width + x_low]); + TAcc v4 = static_cast(bottom_data[y_high * width + x_high]); + TAcc w1 = hy * hx; + TAcc w2 = hy * lx; + TAcc w3 = ly * hx; + TAcc w4 = ly * lx; - T val = is_mode_avg - ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg - : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max + TAcc val = is_mode_avg + ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg + : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max return val; } @@ -97,6 +105,8 @@ __global__ void RoIAlignForward( const bool half_pixel, const int64_t* batch_indices_ptr, const int64_t batch_size) { + using TAcc = AccumulationType_t; + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -111,26 +121,27 @@ __global__ void RoIAlignForward( // If the index is out of range, we set the output to 0 for this RoI element. if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range"); - top_data[index] = 0; + top_data[index] = static_cast(0.0f); continue; } // Do not using rounding; this implementation detail is critical - T roi_offset = half_pixel ? T(0.5) : T(0); - T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset; - T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset; - T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset; - T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; + const TAcc spatial_scale_acc = static_cast(spatial_scale); + const TAcc roi_offset = half_pixel ? static_cast(0.5f) : static_cast(0.0f); + TAcc roi_start_w = static_cast(offset_bottom_rois[0]) * spatial_scale_acc - roi_offset; + TAcc roi_start_h = static_cast(offset_bottom_rois[1]) * spatial_scale_acc - roi_offset; + TAcc roi_end_w = static_cast(offset_bottom_rois[2]) * spatial_scale_acc - roi_offset; + TAcc roi_end_h = static_cast(offset_bottom_rois[3]) * spatial_scale_acc - roi_offset; + + TAcc roi_width = roi_end_w - roi_start_w; + TAcc roi_height = roi_end_h - roi_start_h; if (!half_pixel) { // backward compatibility // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); + roi_width = max(roi_width, static_cast(1.0f)); + roi_height = max(roi_height, static_cast(1.0f)); } - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + const TAcc bin_size_h = roi_height / static_cast(pooled_height); + const TAcc bin_size_w = roi_width / static_cast(pooled_width); const T* offset_bottom_data = bottom_data + static_cast((roi_batch_ind * channels + c) * height * width); @@ -138,26 +149,27 @@ __global__ void RoIAlignForward( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : _Ceil(roi_height / pooled_height); // e.g., = 2 + : static_cast(_Ceil(roi_height / static_cast(pooled_height))); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : _Ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? sampling_ratio : static_cast(_Ceil(roi_width / static_cast(pooled_width))); // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + const int grid_count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + const TAcc count = static_cast(grid_count); // e.g. = 4 - T output_val = 0.; + TAcc output_val = static_cast(0.0f); bool max_flag = false; for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + const TAcc y = roi_start_h + static_cast(ph) * bin_size_h + + (static_cast(iy) + static_cast(0.5f)) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); + const TAcc x = roi_start_w + static_cast(pw) * bin_size_w + + (static_cast(ix) + static_cast(0.5f)) * bin_size_w / + static_cast(roi_bin_grid_w); - T val = bilinear_interpolate( + const TAcc val = bilinear_interpolate( offset_bottom_data, height, width, y, x, is_mode_avg, index); if (is_mode_avg) { @@ -176,7 +188,7 @@ __global__ void RoIAlignForward( output_val /= count; } - top_data[index] = output_val; + top_data[index] = static_cast(output_val); } } @@ -241,6 +253,8 @@ void RoiAlignImpl( SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index 1eeb3683bc9aa..b2abe353693a2 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "test/common/trt_op_test_utils.h" @@ -906,5 +907,138 @@ TEST(RoiAlignTest, BatchIndicesNegative_CUDA) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); #endif } + +TEST(RoiAlignTest, Float16_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, Float16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, BFloat16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToBFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToBFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToBFloat16({1.25f})); + + test.SetOutputTolerance(0.05f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test half_pixel mode (default for Opset 16+) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, sampling_ratio=2. +// Expected values from ONNX reference implementation: {11.25, 14.75, 39.25, 42.75} +TEST(RoiAlignTest, Float16_HalfPixel_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + std::vector X_val(64); + for (int i = 0; i < 64; ++i) X_val[i] = static_cast(i); + test.AddInput("X", {1, 1, 8, 8}, ToFloat16(X_val)); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 7., 7.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 2, 2}, ToFloat16({11.25f, 14.75f, 39.25f, 42.75f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test adaptive sampling (sampling_ratio=0) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, half_pixel mode. +// Adaptive: ceil(3.0/2)=2 samples per dim. +// Expected values from ONNX reference implementation: {11.39062, 14.875, 39.26562, 42.75} +TEST(RoiAlignTest, Float16_AdaptiveSampling_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 0); // adaptive + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + std::vector X_val(64); + for (int i = 0; i < 64; ++i) X_val[i] = static_cast(i); + test.AddInput("X", {1, 1, 8, 8}, ToFloat16(X_val)); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 7., 7.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 2, 2}, + ToFloat16({11.39062f, 14.875f, 39.26562f, 42.75f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime From fde4e032f07169ab5c24c2409040a68d0e7accc2 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:23:48 -0700 Subject: [PATCH 04/17] [CUDA] Extend Pad support through opset 25 with wrap mode (#27774) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This PR consolidates PRs #27416 and #27708 to extend CUDA Pad kernel support through opset 25, including wrap mode implementation. ### Motivation and Context The CUDA execution provider previously only registered the Pad kernel up to opset 18 and did not implement wrap mode. When an ONNX model exported with opset 19+ was run on the CUDA executor, the Pad operation was forced to fall back to CPU, resulting in significant performance degradation. This PR aligns CUDA Pad registration with the ONNX Pad schema evolution through opset 25 and provides a correct wrap mode implementation. Related issues: https://github.com/microsoft/onnxruntime/issues/26393 Related PRs: #27416, #27708 ### Summary of Changes #### Kernel registration and opset coverage | File | Change | |------|--------| | `onnxruntime/core/providers/cuda/tensor/pad.cc` | Adds CUDA Pad kernel registrations for opset ranges 18, 19-20, 21-22, 23, 24, and 25. | | `onnxruntime/core/providers/cuda/cuda_execution_provider.cc` | Registers the new Pad kernel versions in the CUDA EP registry under the existing per-opset sections. | #### CUDA Pad implementation | File | Change | |------|--------| | `onnxruntime/core/providers/cuda/tensor/pad_impl.h` | Extends the Pad kernel interface to pass effective sliced extents and per-axis input offsets. | | `onnxruntime/core/providers/cuda/tensor/pad_impl.cu` | Adds CUDA wrap mode using a `WrapCoordinate` device helper with `if constexpr` compile-time specialization. Removes dead wrap code from the NCHW-specialized kernel path. | | `onnxruntime/core/providers/cuda/tensor/pad.cc` | Computes effective sliced input extents/offsets for wrap behavior with negative pads. Bypasses the NCHW fast-path for wrap mode and routes through the generic implementation. | #### Documentation | File | Change | |------|--------| | `docs/OperatorKernels.md` | Updates the CUDA Pad kernel opset coverage to reflect the new version splits (25+, 24, 23, [21,22], [19,20], 18) up to opset 25. | #### Test coverage | File | Change | |------|--------| | `onnxruntime/test/providers/cpu/tensor/pad_test.cc` | Adds CUDA-only Pad coverage for `edge` across opsets 18-25 and `wrap` across opsets 19-25. Updates existing wrap test comment. | ### Checklist - [x] Tests added/updated - [x] No breaking changes --- ✨ Let Copilot coding agent [set things up for you](https://github.com/microsoft/onnxruntime/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: Shirasawa <764798966@qq.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- docs/OperatorKernels.md | 7 +- .../providers/cuda/cuda_execution_provider.cc | 56 ++++++++++-- onnxruntime/core/providers/cuda/tensor/pad.cc | 72 ++++++++++++++- .../core/providers/cuda/tensor/pad_impl.cu | 90 ++++++++++++------- .../core/providers/cuda/tensor/pad_impl.h | 2 + .../test/providers/cpu/tensor/pad_test.cc | 77 +++++++++++++++- 6 files changed, 259 insertions(+), 45 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 7dfa76193a5b0..625cc4e09ca13 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -847,7 +847,12 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|25+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||24|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||23|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0808a7110f242..4c735fa2d5650 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1444,10 +1444,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize); @@ -1455,6 +1455,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize); // Opset 19 +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool); @@ -1579,6 +1583,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, bool, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Identity); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, If); @@ -1653,6 +1661,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); #endif +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Transpose); @@ -1677,10 +1689,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, bool, Pad); // Opset 25. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Unsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, ConstantOfShape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, If); @@ -2577,10 +2597,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2588,6 +2608,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 19-20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2661,6 +2685,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 21 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // TODO(fajin): support other quantized types BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2758,6 +2786,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2810,10 +2842,18 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 25 BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 9b23209953081..3dd50c1c03cbf 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -40,10 +40,70 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 18, 18, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 19, 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 21, 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 23, 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 24, 24, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Pad, \ kOnnxDomain, \ - 18, \ + 25, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -154,6 +214,11 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { effective_input_extents.push_back(extent); } + TArray input_offsets(dimension_count); + for (int32_t i = 0; i < dimension_count; ++i) { + input_offsets[i] = -(*p_slices)[i]; + } + TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); @@ -236,7 +301,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } - if (IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) { + if (mode_ != Mode::Wrap && + IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) { // If we have entered here, it means the input can only be 4-D (NCHW), 3-D (CHW), or 2-D (HW) // NCHW input @@ -282,6 +348,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { input_dims, input_strides, lower_pads, + TArray(effective_input_extents), + input_offsets, value, static_cast(mode_), reinterpret_cast::MappedType*>(input_tensor.Data()), diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu index 6f530e800fdf2..6020769bf0ddf 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu @@ -7,19 +7,27 @@ namespace onnxruntime { namespace cuda { -// PadMode enum from core/providers/cpu/tensor/pad.h, cannot use that header because of nvcc/onnxruntime incompatibility +// PadMode enum from core/providers/cpu/tensor/padbase.h, cannot use that header because of nvcc/onnxruntime incompatibility enum class PadMode : int { Constant = 0, Reflect, - Edge + Edge, + Wrap }; +__device__ __forceinline__ int64_t WrapCoordinate(int64_t coord, int64_t extent) { + int64_t wrapped = coord % extent; + return wrapped < 0 ? wrapped + extent : wrapped; +} + template __global__ void _PadKernel( const size_t shape_rank, const TArray input_dims, const TArray input_strides, const TArray lower_pads, + const TArray effective_input_extents, + const TArray input_offsets, const T pad_value, const T* input_data, const TArray fdm_output_strides, @@ -33,33 +41,44 @@ __global__ void _PadKernel( int out_coord, r; fdm_output_strides[dim].divmod(output_index, out_coord, r); output_index = r; - int in_coord = 0; - if (out_coord < lower_pads[dim]) { - switch ((PadMode)pad_mode) { - case PadMode::Constant: - use_pad_value = true; - break; - case PadMode::Edge: - in_coord = 0; - break; - case PadMode::Reflect: - in_coord = lower_pads[dim] - out_coord; - break; - } - } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { - switch ((PadMode)pad_mode) { - case PadMode::Constant: - use_pad_value = true; - break; - case PadMode::Edge: - in_coord = input_dims[dim] - 1; - break; - case PadMode::Reflect: - in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); - break; - } + int64_t in_coord = 0; + if constexpr (pad_mode == static_cast(PadMode::Wrap)) { + const int64_t effective_input_extent = effective_input_extents[dim]; + const int64_t pre_pad = lower_pads[dim] + input_offsets[dim]; + const int64_t relative_coord = static_cast(out_coord) - pre_pad; + in_coord = input_offsets[dim] + WrapCoordinate(relative_coord, effective_input_extent); } else { - in_coord = out_coord - lower_pads[dim]; + if (out_coord < lower_pads[dim]) { + switch ((PadMode)pad_mode) { + case PadMode::Constant: + use_pad_value = true; + break; + case PadMode::Edge: + in_coord = 0; + break; + case PadMode::Reflect: + in_coord = lower_pads[dim] - out_coord; + break; + case PadMode::Wrap: + break; + } + } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { + switch ((PadMode)pad_mode) { + case PadMode::Constant: + use_pad_value = true; + break; + case PadMode::Edge: + in_coord = input_dims[dim] - 1; + break; + case PadMode::Reflect: + in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); + break; + case PadMode::Wrap: + break; + } + } else { + in_coord = out_coord - lower_pads[dim]; + } } input_index += input_strides[dim] * in_coord; } @@ -136,6 +155,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& effective_input_extents, + const TArray& input_offsets, const T pad_value, const int pad_mode, const T* input_data, @@ -149,17 +170,22 @@ void PadImpl( switch (pad_mode) { case 0: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; case 1: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; case 2: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, + pad_value, input_data, fdm_output_strides, output_data, N); + break; + case 3: + _PadKernel<<>>( + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; } @@ -211,6 +237,8 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( template void PadImpl(cudaStream_t stream, const size_t shape_rank, \ const TArray& input_dims, const TArray& input_strides, \ const TArray& lower_pads, \ + const TArray& effective_input_extents, \ + const TArray& input_offsets, \ const T pad_value, \ const int pad_mode, \ const T* input_data, \ diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.h b/onnxruntime/core/providers/cuda/tensor/pad_impl.h index dc700ea2304e9..96f158dd187fc 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.h @@ -32,6 +32,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& effective_input_extents, + const TArray& input_offsets, const T pad_value, const int pad_mode, const T* input_data, diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc b/onnxruntime/test/providers/cpu/tensor/pad_test.cc index 9169f2e6b5ca9..990e4354c3626 100644 --- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc @@ -124,6 +124,37 @@ static void RunAllOpsetAllDomainPadTests( } } +#ifdef USE_CUDA +template +static void RunCudaOnlyOnnxOpsetPadTest( + int opset, + const std::vector& input_dims, + const std::vector& input, + const std::vector& pads, + T value, + const std::vector& output_dims, + const std::vector& output, + const std::string& mode = "constant") { + auto cuda_execution_provider = DefaultCudaExecutionProvider(); + if (cuda_execution_provider == nullptr) { + GTEST_SKIP() << "CUDA execution provider is not available"; + } + + OpTester test("Pad", opset); + if (mode != "constant") { + test.AddAttribute("mode", mode); + } + test.AddInput("data", input_dims, input); + test.AddInput("pads", {static_cast(pads.size())}, pads, true); + test.AddInput("value", {}, {value}, true); + test.AddOutput("output", output_dims, output); + + std::vector> execution_providers; + execution_providers.emplace_back(std::move(cuda_execution_provider)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + // Some of the tests can't run on TensorrtExecutionProvider because only constant mode and value 0 of "Pad" node is supported. // Those tests will fallback to other EP. @@ -199,6 +230,48 @@ TYPED_TEST(PadOpTest, Pad_Edge_1D) { "edge"); } +#ifdef USE_CUDA +TEST(PadOpTest, Pad_Edge_CudaOnly_MLFloat16_SupportedOpsets) { + const std::vector supported_opsets{18, 19, 20, 21, 22, 23, 24, 25}; + for (int opset : supported_opsets) { + SCOPED_TRACE(MakeString("opset: ", opset)); + RunCudaOnlyOnnxOpsetPadTest( + opset, + {3, 2}, + {MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f)}, + {0, 2, 0, 1}, + MLFloat16(0.0f), + {3, 5}, + {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(6.0f)}, + "edge"); + } +} + +TEST(PadOpTest, Pad_Wrap_CudaOnly_Float_SupportedOpsets) { + const std::vector supported_opsets{19, 20, 21, 22, 23, 24, 25}; + for (int opset : supported_opsets) { + SCOPED_TRACE(MakeString("opset: ", opset)); + RunCudaOnlyOnnxOpsetPadTest( + opset, + {3, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + 5.0f, 6.0f}, + {0, 1, 0, 1}, + 0.0f, + {3, 4}, + {2.0f, 1.0f, 2.0f, 1.0f, + 4.0f, 3.0f, 4.0f, 3.0f, + 6.0f, 5.0f, 6.0f, 5.0f}, + "wrap"); + } +} +#endif + TYPED_TEST(PadOpTest, Pad_Constant_2D) { using T = TypeParam; RunAllOpsetAllDomainPadTests({2, 2}, @@ -1391,9 +1464,7 @@ TEST(PadOpTest, Pad_Wrap_NegativeFront_PositiveBack) { // Post-slice core: [4]; wrap 3 -> [4, 4, 4, 4] const std::vector expected_data = {4, 4, 4, 4}; - // CUDA registers only up to 18 and does not impl wrap mode - // so we force version to 19 to automatically exclude EPs that do not - // implement wrap mode similar to the above tests. + // Use opset 19 to exercise wrap mode, which is supported from Pad-19 onward. OpTester test("Pad", 19); test.AddInput("data", input_shape, input_data); test.AddInput("pads", {static_cast(pads.size())}, pads, true); From 36c962fdc5c4b29011dd488f0b9a86df85a23e2b Mon Sep 17 00:00:00 2001 From: derdeljan-msft Date: Mon, 23 Mar 2026 21:15:06 +0100 Subject: [PATCH 05/17] Fix QNN SDK version propagation in Linux ort-qnn wheel build (#27800) ### Description QNN python wheel pipeline for Linux didn't take into account the QNN version that the user can override when running the pipeline and always used whatever was the default in the pipeline yamls. Allow for overriding QNN version in onnxruntime-qnn Linux python wheel pipelines. --- .../github/azure-pipelines/stages/py-cpu-packaging-stage.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index f767ef110561a..88d2981e2ccaa 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -194,6 +194,7 @@ stages: - template: ../templates/py-linux-qnn.yml parameters: machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' + QnnSdk: ${{ parameters.qnn_sdk_version }} extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true From c1719194acecaa2ca9f2250c218127e8a1c6246a Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Mon, 23 Mar 2026 15:27:33 -0700 Subject: [PATCH 06/17] Fix NeonFp16DequantB8Bit reference to match kernel fp16 precision (#27812) The kernel computes neg_scaled_zp = -(scale * zp) in fp16 first (intermediate rounding), then uses it in the fma. For scale*zp in range [128, 256), the fp16 ULP is 0.125, so this intermediate rounding error (~0.06) propagates to the result and exceeded the test tolerance. The reference was computing everything in fp32 and converting to fp16 only at the end, avoiding this intermediate rounding. This caused mismatches up to 0.15 (29 fp16 ULPs). Fix: emulate the kernel's fp16 computation order in the reference: 1. neg_szp = MLAS_FP16(-(scale * zp)).ToFloat() // fp16 round-trip 2. result = MLAS_FP16(neg_szp + value * scale) // emulates fma --- .../test/mlas/unittest/test_hqnbitgemm_neon.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp index 14e05fd42538e..3762e30af352d 100644 --- a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp +++ b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp @@ -586,9 +586,11 @@ class MlasNeonFp16DequantB8BitTest : public MlasTestBase { // Reference dequantization for 8-bit packed data. // Uses explicit position-based indexing to match the packed layout exactly. + // Emulates the kernel's fp16 computation order: + // 1. neg_scaled_zp = fp16_round(-(scale * zp)) [once per block per column] + // 2. result = fp16_round(neg_scaled_zp + value * scale) [emulates fp16 fma] // // Packed layout for N>=8 group (8N-interleaved): - // For each K position, 8 consecutive bytes hold one value per column. // byte[groupStart + k * 8 + col] = value for K=k, column=col // // Packed layout for remainder N<8 (sequential): @@ -610,12 +612,14 @@ class MlasNeonFp16DequantB8BitTest : public MlasTestBase { for (size_t col = 0; col < 8; ++col) { const size_t absCol = n + col; const size_t srcIdx = groupStart + k * 8 + col; - const size_t dstIdx = srcIdx; // output has the same interleaved layout + const size_t dstIdx = srcIdx; const float value = static_cast(src[srcIdx]); const float scale = scales[absCol * BlkNum + block].ToFloat(); const float zp = static_cast( UseZeroPoints ? zero_points[absCol * BlkNum + block] : 128); - dst[dstIdx] = MLAS_FP16(value * scale - zp * scale); + // Emulate kernel: neg_scaled_zp rounded to fp16, then fma + const float neg_szp = MLAS_FP16(-(scale * zp)).ToFloat(); + dst[dstIdx] = MLAS_FP16(neg_szp + value * scale); } } } @@ -631,7 +635,8 @@ class MlasNeonFp16DequantB8BitTest : public MlasTestBase { const float scale = scales[n * BlkNum + block].ToFloat(); const float zp = static_cast( UseZeroPoints ? zero_points[n * BlkNum + block] : 128); - dst[dstIdx] = MLAS_FP16(value * scale - zp * scale); + const float neg_szp = MLAS_FP16(-(scale * zp)).ToFloat(); + dst[dstIdx] = MLAS_FP16(neg_szp + value * scale); } } } From 0c3e5fcfb088634beddbe884eac38bb5659499a8 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 23 Mar 2026 15:30:04 -0700 Subject: [PATCH 07/17] fix webnn/where complicance tests for webgpu (#27776) webgpu ep will now pass https://wpt.live/webnn/conformance_tests/where.https.any.html?gpu --- onnxruntime/core/providers/webgpu/tensor/where.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index 3560fba522cb8..428fe863ab61b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -82,7 +82,7 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { -> void { const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; - const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; + const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << u32(component_c" + x + " * 8)))"; shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4 + " + x) << ";\n" << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" From 16b556d37cb74a056f58a4cd992ffb81c6e49674 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 24 Mar 2026 00:16:17 -0700 Subject: [PATCH 08/17] =?UTF-8?q?fix=20webnn=20test=20case=20for=20webgpu?= =?UTF-8?q?=20ep:=20'transpose=20float32=201D=20constant=20tensor=20defaul?= =?UTF-8?q?t=20op=E2=80=A6=20(#27773)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fixes a few webnn test cases for webgpu ep. This https://wpt.live/webnn/conformance_tests/transpose.https.any.html?gpu should pass 100% now --- onnxruntime/core/providers/webgpu/tensor/transpose.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 230d172d7404e..5cc09501ab378 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -244,6 +244,11 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { return Status::OK(); } + // 1D transpose is identity - just copy the GPU buffer. + if (rank == 1) { + return Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor); + } + return DoTranspose(context, *p_perm, *input_tensor, *output_tensor); } From 37b863c617af3900b28b9588c68f784090b8101c Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 24 Mar 2026 03:16:43 -0700 Subject: [PATCH 09/17] =?UTF-8?q?Extend=20DQ=E2=86=92MatMulNBits=20fusion?= =?UTF-8?q?=20to=20support=20Gemm=20+=20per-tensor/per-channel=20quantizat?= =?UTF-8?q?ion=20(#27769)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the QDQ `DQMatMulToMatMulNBits` fusion to handle additional quantization patterns beyond the existing blockwise DQ→MatMul case. ### New support - **Gemm**: Fuses DQ→Gemm (with optional bias, including DQ bias) into MatMulNBits, stripping Gemm-specific attributes (`alpha`, `beta`, `transB`). - **Per-tensor & per-channel quantization**: Expands scalar/1D scales and zero-points into block-quantized format expected by MatMulNBits. Block size is configurable via `session.qdq_matmulnbits_block_size` (default: 32). ### Changes - **Selectors** (qdq_selectors.cc): Replaced `ValidateBlockwiseDQForMatMulNBits` with `ValidateDQForMatMulNBits` supporting all three quantization modes. Added Gemm-specific validation. - **Actions** (qdq_actions.cc): Added scale/zp expansion for non-blockwise cases, Gemm attribute cleanup, and bias wiring to MatMulNBits input 5. - **Registration** (qdq_selector_action_transformer.cc): Registered `Gemm` alongside `MatMul`; threaded `qdq_matmulnbits_block_size` from session config. - **Tests** (qdq_matmulnbits_transformer_test.cc): Added tests for per-tensor, per-channel, Gemm (no bias, constant bias, DQ bias), block size options, and negative cases. --- .../onnxruntime_session_options_config_keys.h | 7 + .../core/optimizer/graph_transformer_utils.cc | 14 +- .../selectors_actions/qdq_actions.cc | 228 +++++- .../selectors_actions/qdq_actions.h | 4 +- .../qdq_selector_action_transformer.cc | 23 +- .../qdq_selector_action_transformer.h | 3 +- .../selectors_actions/qdq_selectors.cc | 206 +++++- .../selectors_actions/qdq_selectors.h | 6 +- .../qdq_matmulnbits_transformer_test.cc | 690 +++++++++++++++--- 9 files changed, 1033 insertions(+), 148 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index f0a99bc11c8b3..a9d9ac8323b16 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -391,6 +391,13 @@ static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_k // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; +// Block size used when converting per-tensor or per-axis DQ + MatMul to MatMulNBits. +// Only applies to DQ nodes without an existing block_size attribute (i.e., per-tensor or per-axis quantization). +// Positive value: explicit block_size (must be power-of-2 and >= 16, e.g., 16, 32, 64, 128). +// "0" or not provided: use default block_size of 32. +// "-1": heuristic - largest power-of-2 <= min(K, 256) that minimizes padding. +static const char* const kOrtSessionOptionsQDQMatMulNBitsBlockSize = "session.qdq_matmulnbits_block_size"; + // Enable the DQ->MatMulNBits fusion graph transformer. // "0": disabled (default). "1": enabled. // This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered. diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ac712084012a4..9ed1d5e9e84fa 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -355,6 +355,10 @@ InlinedVector> GenerateTransformers( ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "4")); + const int64_t qdq_matmulnbits_block_size = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize, + "0")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -371,7 +375,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + qdq_matmulnbits_block_size)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -513,6 +518,10 @@ InlinedVector> GenerateTransformersForMinimalB ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "4")); + const int64_t qdq_matmulnbits_block_size = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize, + "0")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; @@ -520,7 +529,8 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + qdq_matmulnbits_block_size)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index fdc0818e8437b..b9d7e898157bd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -42,6 +42,67 @@ bool IsDQWeightSigned(int32_t dt_weight) { dt_weight == TensorProto::INT8; } +// Compute the effective block_size for per-tensor/per-channel DQ nodes that lack a block_size attribute. +// session_block_size: 0 = default (32), positive = explicit, -1 = min-padding heuristic. +int64_t ComputeEffectiveBlockSize(int64_t session_block_size, int64_t K) { + // MatMulNBits CPU kernel currently only supports block_size in [16, 256] correctly. + constexpr int64_t kMinBlockSize = 16; + constexpr int64_t kMaxBlockSize = 256; + + if (session_block_size > 0) { + // Explicit block_size — must be power-of-2 and within [kMinBlockSize, kMaxBlockSize]. + ORT_ENFORCE(session_block_size >= kMinBlockSize && + ((session_block_size & (session_block_size - 1)) == 0), + "Explicit qdq_matmulnbits_block_size must be a power-of-2 and >= ", + kMinBlockSize, ", got: ", session_block_size); + ORT_ENFORCE(session_block_size <= kMaxBlockSize, + "Explicit qdq_matmulnbits_block_size must be <= ", + kMaxBlockSize, ", got: ", session_block_size); + return session_block_size; + } + + if (session_block_size == -1) { + // Heuristic: largest power-of-2 <= min(K, kMaxBlockSize) that minimizes padding. + // Capped at kMaxBlockSize because CPU EP only supports block_size up to kMaxBlockSize correctly. + // We want ceil(K / B) * B - K to be minimized (least wasted padding). + int64_t best_bs = kMinBlockSize; + int64_t best_padding = (((K + (kMinBlockSize - 1)) / kMinBlockSize) * kMinBlockSize) - K; + for (int64_t bs = kMinBlockSize * 2; bs <= std::min(K, kMaxBlockSize); bs *= 2) { + int64_t padding = (((K + bs - 1) / bs) * bs) - K; + if (padding <= best_padding) { + best_padding = padding; + best_bs = bs; + } + } + return best_bs; + } + + // Default (session_block_size == 0): use 32 + return 32; +} + +// Get the DQ block_size: from the attribute if blockwise, or computed for per-tensor/per-channel. +int64_t GetEffectiveBlockSize(const Node& dq_node, int64_t block_size_for_non_blockwise) { + const auto& dq_attrs = dq_node.GetAttributes(); + const auto bs_iter = dq_attrs.find("block_size"); + if (bs_iter != dq_attrs.end() && bs_iter->second.i() > 0) { + return bs_iter->second.i(); + } + + // Derive K from the weight input shape if available. Shape information may be missing even + // when the weight is a constant initializer, so guard against nullptrs / unknown dims. + int64_t K = 32; // reasonable default consistent with ComputeEffectiveBlockSize default + const auto* weight_arg = dq_node.InputDefs()[0]; + if (weight_arg != nullptr) { + const auto* shape = weight_arg->Shape(); + if (shape != nullptr && shape->dim_size() > 0 && shape->dim(0).has_dim_value()) { + K = static_cast(shape->dim(0).dim_value()); + } + } + + return ComputeEffectiveBlockSize(block_size_for_non_blockwise, K); +} + // Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits. // Used by DQMatMulToMatMulNBitsAction. struct TransposedQuantizedTensors { @@ -56,16 +117,17 @@ struct TransposedQuantizedTensors { // Transpose DQ weight/scale/zp tensors from column-wise layout to MatMulNBits layout via MLAS. // default_zp_name_prefix: prefix for auto-generated zero-point name when unsigned type has no explicit zp. +// effective_block_size: the block_size to use for MatMulNBits (may differ from DQ's block_size for per-tensor/per-channel). Status TransposeDQWeightsForMatMulNBits( Graph& graph, const Node& dq_node, const std::string& default_zp_name_prefix, concurrency::ThreadPool* intra_op_thread_pool, + int64_t effective_block_size, TransposedQuantizedTensors& result) { const auto* weight_arg = dq_node.InputDefs()[0]; const auto* scale_arg = dq_node.InputDefs()[1]; const auto* zp_arg = dq_node.InputDefs().size() > 2 ? dq_node.InputDefs()[2] : nullptr; - const auto& attrs = dq_node.GetAttributes(); const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), @@ -78,9 +140,11 @@ Status TransposeDQWeightsForMatMulNBits( graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); } - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = attrs.at("block_size").i(); + ORT_RETURN_IF_NOT(weight_tensor_proto->dims_size() >= 2, + "Weight tensor for node ", dq_node.Name(), " must be at least 2D."); + auto K = weight_tensor_proto->dims(0); + auto N = weight_tensor_proto->dims(1); + auto block_size = effective_block_size; int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); auto bits = DQWeightBits(dt_weight); auto quant_num = (K + block_size - 1) / block_size; @@ -94,8 +158,100 @@ Status TransposeDQWeightsForMatMulNBits( std::optional zp_src; auto cpu_allocator = CPUAllocator::DefaultInstance(); + // Determine if scale/zp need expansion from per-tensor/per-channel to blockwise [quant_num, N]. + const bool is_blockwise = (scale_tensor_proto->dims_size() == 2); + std::optional expanded_scale; + std::optional expanded_zp; + + if (!is_blockwise) { + // Expand scale to [quant_num, N] + expanded_scale.emplace(scale_type, TensorShape{quant_num, N}, cpu_allocator); + bool is_per_tensor = (scale_tensor_proto->dims_size() == 0); + + auto expand_scale = [&](auto* src_data, auto* dst_data) { + if (is_per_tensor) { + auto val = src_data[0]; + for (int64_t i = 0; i < quant_num * N; ++i) { + dst_data[i] = val; + } + } else { + // Per-channel: scale shape [N], replicate across quant_num blocks + for (int64_t b = 0; b < quant_num; ++b) { + for (int64_t n = 0; n < N; ++n) { + dst_data[b * N + n] = src_data[n]; + } + } + } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + expand_scale(scale_src.data(), expanded_scale->MutableData()); + } else { + expand_scale(scale_src.data(), expanded_scale->MutableData()); + } + + // Expand zp if present + if (zp_tensor_proto) { + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); + // Allocate as uint8 with enough bytes to hold quant_num*N packed sub-byte elements. + int64_t expanded_zp_bytes = (quant_num * N * bits + 7) / 8; + expanded_zp.emplace(uint8_type, TensorShape{expanded_zp_bytes}, cpu_allocator); + + // For sub-byte types, the zp is packed in bytes. We need to expand element-wise. + // For 8-bit, each byte is one element. For 4-bit, 2 elements per byte. For 2-bit, 4 elements per byte. + const uint8_t* zp_bytes = zp_src->DataAsByteSpan().data(); + uint8_t* dst_zp_bytes = expanded_zp->MutableData(); + + auto get_element = [bits](const uint8_t* data, int64_t idx) -> uint8_t { + if (bits == 8) return data[idx]; + if (bits == 4) { + uint8_t byte = data[idx / 2]; + return (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4); + } + // bits == 2 + uint8_t byte = data[idx / 4]; + int shift = static_cast((idx % 4) * 2); + return (byte >> shift) & 0x03; + }; + + auto set_element = [bits](uint8_t* data, int64_t idx, uint8_t val) { + if (bits == 8) { + data[idx] = val; + return; + } + if (bits == 4) { + int64_t byte_idx = idx / 2; + if (idx % 2 == 0) { + data[byte_idx] = (data[byte_idx] & 0xF0) | (val & 0x0F); + } else { + data[byte_idx] = (data[byte_idx] & 0x0F) | ((val & 0x0F) << 4); + } + return; + } + // bits == 2 + int64_t byte_idx = idx / 4; + int shift = static_cast((idx % 4) * 2); + data[byte_idx] = (data[byte_idx] & ~(0x03 << shift)) | ((val & 0x03) << shift); + }; + + // Initialize expanded zp to 0 + memset(dst_zp_bytes, 0, expanded_zp->SizeInBytes()); + + for (int64_t b = 0; b < quant_num; ++b) { + for (int64_t n = 0; n < N; ++n) { + int64_t src_idx = is_per_tensor ? 0 : n; + uint8_t val = get_element(zp_bytes, src_idx); + set_element(dst_zp_bytes, b * N + n, val); + } + } + } + } + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); result.weight = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + // Zero-initialize: MLAS 4-bit transpose does not zero-pad when K < block_size, + // leaving uninitialized bytes in the last block's padding region. + memset(result.weight.MutableDataRaw(), 0, result.weight.SizeInBytes()); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); auto scale_size = (TensorShape{N, quant_num}).Size(); @@ -104,7 +260,13 @@ Status TransposeDQWeightsForMatMulNBits( std::string zp_dst_name; auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); - if (zp_tensor_proto) { + if (!is_blockwise && expanded_zp.has_value()) { + // Per-tensor/per-channel path with expanded zero-point + zp_dst_name = graph.GenerateNodeArgName( + (zp_arg ? zp_arg->Name() : default_zp_name_prefix + "_zero_point") + "_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + } else if (zp_tensor_proto) { + // Blockwise path with explicit zero-point zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); @@ -116,10 +278,15 @@ Status TransposeDQWeightsForMatMulNBits( // Dispatch MLAS transpose based on scale type, bits, and signedness. auto transpose = [&](auto* scale_data, auto* scale_dst_data) { - using ScaleType = std::remove_pointer_t; + using ScaleType = std::remove_const_t>; bool is_signed = IsDQWeightSigned(dt_weight); const uint8_t* src_w = weight_src.DataAsByteSpan().data(); - const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + const uint8_t* src_zp = nullptr; + if (expanded_zp.has_value()) { + src_zp = expanded_zp->Data(); + } else if (zp_src.has_value()) { + src_zp = zp_src->DataAsByteSpan().data(); + } uint8_t* dst_w = result.weight.MutableData(); uint8_t* dst_zp = result.zero_point ? result.zero_point->MutableData() : nullptr; int K_int = static_cast(K); @@ -148,9 +315,11 @@ Status TransposeDQWeightsForMatMulNBits( }; if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - transpose(scale_src.data(), result.scale.MutableData()); + const float* s_data = expanded_scale.has_value() ? expanded_scale->Data() : scale_src.data(); + transpose(s_data, result.scale.MutableData()); } else { - transpose(scale_src.data(), result.scale.MutableData()); + const MLFloat16* s_data = expanded_scale.has_value() ? expanded_scale->Data() : scale_src.data(); + transpose(s_data, result.scale.MutableData()); } result.weight_proto = utils::TensorToTensorProto(result.weight, weight_dst_name, true); @@ -430,7 +599,8 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) + concurrency::ThreadPool* intra_op_thread_pool, + int64_t block_size_for_non_blockwise) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -440,7 +610,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + block_size_for_non_blockwise_{block_size_for_non_blockwise} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -449,15 +620,17 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) NodeAttributes extra_attributes; const auto* dq_node = runtime_state.selected_nodes.Input(0); - auto& attrs = dq_node->GetAttributes(); const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + ORT_ENFORCE(weight_shape != nullptr && weight_shape->dim_size() >= 2, + "Weight shape unavailable for DQ node ", dq_node->Name()); utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); int32_t dt_weight = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); utils::SetNodeAttribute(utils::MakeAttribute("bits", DQWeightBits(dt_weight)), extra_attributes); - utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); + int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", effective_bs), extra_attributes); return extra_attributes; } @@ -467,9 +640,11 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, Node& replacement_node) const { const auto* dq_node = selected_nodes.Input(0); + int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_); + TransposedQuantizedTensors transposed; ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( - graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, transposed)); + graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, effective_bs, transposed)); auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); @@ -483,6 +658,31 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, replacement_node.MutableInputArgsCount().push_back(1); } + // If the target was Gemm, strip Gemm-specific attributes from the replacement MatMulNBits node + // and wire the bias (if present) to MatMulNBits input 5. + const auto& target = selected_nodes.Target(); + if (target.OpType() == "Gemm") { + replacement_node.ClearAttribute("alpha"); + replacement_node.ClearAttribute("beta"); + replacement_node.ClearAttribute("transA"); + replacement_node.ClearAttribute("transB"); + + // Wire Gemm bias to MatMulNBits input 5 (bias slot). + // The bias can be a direct float tensor or the output of a DQ node. + const auto& target_inputs = target.InputDefs(); + if (target_inputs.size() > 2 && target_inputs[2] && target_inputs[2]->Exists()) { + // MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zp(opt), 4:g_idx(opt), 5:bias(opt) + // Pad with empty NodeArgs up to position 5. + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (input_defs.size() < 5) { + input_defs.push_back(&empty_arg); + replacement_node.MutableInputArgsCount().push_back(1); + } + input_defs.push_back(const_cast(target_inputs[2])); + replacement_node.MutableInputArgsCount().push_back(1); + } + } + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 02a8353707599..f0b1e17a7ffe0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -86,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + int64_t block_size_for_non_blockwise = 0); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -105,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + const int64_t block_size_for_non_blockwise_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 8cab6911646f2..c88ae9b8c4782 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -296,15 +296,19 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is 2/4/8-bit int (int2/uint2, int4/uint4, int8/uint8). DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. + // Also supports per-tensor and per-channel (axis=1) quantized DQ weights by expanding + // scales/zero-points to blockwise format using qdq_matmulnbits_block_size. const std::string action_name{"DQMatMulToMatMulNBits"}; std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + qdq_matmulnbits_block_size); #if !defined(ORT_MINIMAL_BUILD) // Include "" (empty string) to match nodes not yet assigned to an EP. @@ -315,7 +319,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""}; std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, - {{"MatMul", {}}}, + {{"MatMul", {}}, + {"Gemm", {}}}, std::move(selector), std::move(action)); @@ -370,7 +375,8 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { SelectorActionRegistry CreateSelectorActionRegistry( bool is_int8_allowed, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -384,7 +390,8 @@ SelectorActionRegistry CreateSelectorActionRegistry( WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + qdq_matmulnbits_block_size); return qdq_selector_action_registry; } @@ -395,11 +402,13 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) : SelectorActionTransformer{ "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool), + intra_op_thread_pool, + qdq_matmulnbits_block_size), apply_context, // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index dce1cd44fd3ea..8294c839cfe42 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -29,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + int64_t qdq_matmulnbits_block_size = 0); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 8a00fe11ff3fd..ef9e1b0cad490 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -6,6 +6,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/graph/graph.h" +#include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -558,11 +559,14 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } } -// Validate that a DQ node has the correct structure for MatMulNBits fusion: -// - weight type is 2/4/8-bit int, scale type is float or float16 -// - blockwise quantization along axis 0, block_size is power-of-2 and >= 16 -// - weight/scale/zp are constant initializers with rank 2 and consistent shapes -static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq_node) { +// Validate that a DQ node has the correct structure for MatMulNBits fusion. +// Supports three quantization granularities: +// - Blockwise: axis=0, block_size >= 16 and power-of-2, scale/zp rank 2 +// - Per-tensor: scale is scalar (rank 0), no block_size attribute +// - Per-channel (axis=1): scale is 1D with shape [N], weight is 2D [K,N], no block_size attribute +// In all cases: weight type is 2/4/8-bit int, scale type is float or float16, +// weight/scale/zp are constant initializers. +static bool ValidateDQForMatMulNBits(const Graph& graph, const Node& dq_node) { const auto* weight_arg = dq_node.InputDefs()[0]; const auto* scale_arg = dq_node.InputDefs()[1]; const auto* zero_point_arg = dq_node.InputDefs().size() == 3 ? dq_node.InputDefs()[2] : nullptr; @@ -578,22 +582,6 @@ static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq return false; } - // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 - const auto& dq_attrs = dq_node.GetAttributes(); - if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { - return false; - } - - const auto a_iter = dq_attrs.find("block_size"); - if (a_iter == dq_attrs.end()) { - return false; - } - - auto block_size = a_iter->second.i(); - if (block_size < 16 || ((block_size - 1) & block_size)) { - return false; - } - // weight, scale and zero points (if exists) must be constants const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); @@ -607,18 +595,124 @@ static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq return false; } - // weight, scale and zero points (if exists) must have the rank 2 - if (weight_tensor_proto->dims_size() != 2 || scale_tensor_proto->dims_size() != 2 || - (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + // weight must be rank 2 + if (weight_tensor_proto->dims_size() != 2) { return false; } - // check weight, scale and zero points (if exists) shapes - if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || - weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || - (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || - zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + const auto& dq_attrs = dq_node.GetAttributes(); + const auto block_size_iter = dq_attrs.find("block_size"); + const bool has_block_size = block_size_iter != dq_attrs.end() && block_size_iter->second.i() > 0; + + if (has_block_size) { + // --- Blockwise path (existing logic) --- + if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return false; + } + + auto block_size = block_size_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return false; + } + + if (scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return false; + } + + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return false; + } + } else { + // --- Per-tensor or per-channel path --- + int scale_rank = scale_tensor_proto->dims_size(); + auto N = weight_tensor_proto->dims()[1]; + + if (scale_rank == 0) { + // Per-tensor: scalar scale, optional scalar zp + if (zp_tensor_proto && zp_tensor_proto->dims_size() != 0) { + return false; + } + } else if (scale_rank == 1 && scale_tensor_proto->dims()[0] == N) { + // Per-channel (axis=1): scale shape [N], axis must be 1 + const auto a_iter = dq_attrs.find("axis"); + // DQ default axis is 1, so absent axis is OK + if (a_iter != dq_attrs.end() && a_iter->second.i() != 1) { + return false; + } + if (zp_tensor_proto && (zp_tensor_proto->dims_size() != 1 || zp_tensor_proto->dims()[0] != N)) { + return false; + } + } else { + // Unsupported quantization granularity + return false; + } + } + + return true; +} + +// Validate Gemm attributes for DQ->MatMulNBits fusion. +// Gemm must be equivalent to MatMul: alpha=1, transA=0, transB=0. +// If bias exists, beta must be 1 and bias shape must be [N]. +static bool ValidateGemmForDQMatMulNBits(const Graph& graph, const Node& gemm_node, const Node& weight_dq_node) { + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm_node, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) + return false; + if (const auto* trans_a = graph_utils::GetNodeAttribute(gemm_node, "transA"); + trans_a && trans_a->i() != 0) + return false; + if (const auto* trans_b = graph_utils::GetNodeAttribute(gemm_node, "transB"); + trans_b && trans_b->i() != 0) return false; + + const auto& inputs = gemm_node.InputDefs(); + if (inputs.size() > 2 && inputs[2] && inputs[2]->Exists()) { + // Bias exists — beta must be 1.0 + if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm_node, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) + return false; + + // Bias shape must be [N] where N = weight dim 1. Prefer reading N and + // bias length from constant initializers when available, and fall back to + // NodeArg::Shape(). + const auto* weight_arg = weight_dq_node.InputDefs()[0]; + const auto* weight_initializer = graph.GetConstantInitializer(weight_arg->Name(), true); + int64_t N = -1; + + if (weight_initializer) { + if (weight_initializer->dims_size() != 2) { + return false; + } + N = weight_initializer->dims(1); + } else { + const auto* weight_shape = weight_arg->Shape(); + if (!weight_shape || weight_shape->dim_size() != 2 || + !utils::HasDimValue(weight_shape->dim(1))) { + return false; + } + N = weight_shape->dim(1).dim_value(); + } + + const auto* bias_arg = inputs[2]; + const auto* bias_initializer = graph.GetConstantInitializer(bias_arg->Name(), true); + + if (bias_initializer) { + if (bias_initializer->dims_size() != 1 || + bias_initializer->dims(0) != N) { + return false; + } + } else { + const auto* bias_shape = bias_arg->Shape(); + if (!bias_shape || bias_shape->dim_size() != 1 || + !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != N) { + return false; + } + } } return true; @@ -637,18 +731,55 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod } const auto& graph = graph_viewer.GetGraph(); + const bool is_gemm = node.OpType() == "Gemm"; + + if (is_gemm) { + // Gemm: accept 1 DQ (weight only) or 2 DQs (weight + bias). + if (dq_nodes.size() < 1 || dq_nodes.size() > 2) { + return false; + } + } else { + // MatMul: exactly 1 DQ input + if (dq_nodes.size() != 1) { + return false; + } + } - // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output - if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + // Find the weight DQ node — the one feeding input 1 (B) + const Node* weight_dq = nullptr; + for (const auto* dq : dq_nodes) { + if (node.InputDefs()[1] == dq->OutputDefs()[0]) { + weight_dq = dq; + break; + } + } + + if (!weight_dq) { return false; } - // DQ must be MatMul's the second input - if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + // Weight DQ must have exactly 1 output edge and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *weight_dq, 1)) { return false; } - return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]); + if (is_gemm) { + // If there's a second DQ node (for bias), it must feed input 2 + if (dq_nodes.size() == 2) { + const Node* bias_dq = (dq_nodes[0] == weight_dq) ? dq_nodes[1] : dq_nodes[0]; + if (node.InputDefs().size() <= 2 || !node.InputDefs()[2] || + node.InputDefs()[2] != bias_dq->OutputDefs()[0]) { + return false; + } + } + + // Validate Gemm attributes (alpha=1, transA=0, transB=0, beta=1 if bias) + if (!ValidateGemmForDQMatMulNBits(graph, node, *weight_dq)) { + return false; + } + } + + return ValidateDQForMatMulNBits(graph, *weight_dq); } bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, @@ -701,6 +832,13 @@ void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex); } +void DQMatMulToMatMulNBitsSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { + // Keep only the weight DQ (first entry). If a Gemm has a bias DQ, it will be in + // position 1 — trim it so RemoveNodes does not delete it. The bias DQ's output + // is wired to MatMulNBits input 5 in ProcessNewNode. + builder.input_nodes.resize(1); +} + bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 79c374b301442..10d307b4a003c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -454,11 +454,15 @@ class MatMulSelector : public BaseSelector { compatible_providers) {} }; -// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +// Convert "1 DQ node for input B -> MatMul/Gemm" to "MatMulNBits" class DQMatMulToMatMulNBitsSelector : public BaseSelector { public: explicit DQMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) : BaseSelector(std::make_unique(), compatible_providers) {} + + // Only keep the weight DQ in the selection. Any bias DQ (for Gemm) is excluded + // so that RemoveNodes does not remove it — its output is wired through to MatMulNBits. + void UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const override; }; // Input: DQ nodes for A, B and optional C diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 5d7eda39be271..03005e3a07386 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -114,43 +114,15 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { // DQ contrib op schema is not updated to support blocked quantization + // Rejection doesn't depend on type/zp/accuracy_level — keep representative combos only. RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); } // Input2 @@ -179,7 +151,7 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { @@ -224,42 +196,13 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); } // Input1 @@ -295,7 +238,7 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { @@ -353,52 +296,27 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { - // DQ contrib op schema is not updated to support blocked quantization + // One representative type combo per rejection scenario (type doesn't affect rejection logic). // block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch_Cuda) { - // DQ contrib op schema is not updated to support blocked quantization + // One representative type combo per rejection scenario. // block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); } // Input1 @@ -727,7 +645,7 @@ RunDQMatMulFP16Converted(const std::vector& input1_shape, utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); @@ -780,6 +698,602 @@ TEST(QDQTransformerTests, DQMatMulFP16ConvertedToMatMulNBits) { RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); } +// Per-tensor DQ -> MatMul conversion to MatMulNBits +// DQ has scalar scale (and optional scalar zero-point), no block_size attribute. +// Input1 +// | DQ(per-tensor) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulPerTensorConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + // Scalar scale (per-tensor) + auto* scale_arg = builder.MakeInitializer({}, {10.0f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, std::vector{T(1, 0)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 0.01 /*per_sample_tolerance - higher due to blockwise accumulation reordering*/, + 5e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorConvertedToMatMulNBits) { + // Per-tensor: cover both types and a non-divisible K case. + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); +} + +// Per-channel (axis=1) DQ -> MatMul conversion to MatMulNBits +// DQ has 1D scale shape [N], axis=1, no block_size attribute. +// Input1 +// | DQ(per-channel axis=1) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulPerChannelConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + int64_t N = weight_shape[1]; + // 1D scale shape [N] for per-channel (axis=1) + auto* scale_arg = builder.MakeInitializer({N}, 8.0f, 12.0f); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(1)), attrs); + + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(std::vector{N}, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerChannelConvertedToMatMulNBits) { + RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); +} + +// Negative test: per-axis axis=0 with 1D scale should NOT fuse +template +void RunDQMatMulPerAxisAxis0NotConverted(const std::vector& input1_shape, + const std::vector& weight_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + int64_t K = weight_shape[0]; + // 1D scale shape [K] for per-axis axis=0 — should NOT match + auto* scale_arg = builder.MakeInitializer({K}, 8.0f, 12.0f); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), attrs); + + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn = [](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "0"); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerAxisAxis0NotConvertedToMatMulNBits) { + RunDQMatMulPerAxisAxis0NotConverted({12, 32}, {32, 16}); +} + +// Per-tensor DQ -> MatMul with configurable block_size session option +template +void RunDQMatMulPerTensorWithBlockSize(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t block_size_option) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + auto* scale_arg = builder.MakeInitializer({}, {10.0f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, std::vector{T(1, 0)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + + // Verify the MatMulNBits node has the expected block_size attribute + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "MatMulNBits") { + auto& attrs = node.GetAttributes(); + auto bs_iter = attrs.find("block_size"); + ASSERT_NE(bs_iter, attrs.end()); + int64_t expected_bs = block_size_option > 0 ? block_size_option : 32; // default is 32 + EXPECT_EQ(bs_iter->second.i(), expected_bs); + } + } + }; + + std::function add_session_options_fn = + [block_size_option](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "0"); + std::ignore = sess_opts.config_options.AddConfigEntry( + kOrtSessionOptionsQDQMatMulNBitsBlockSize, + std::to_string(block_size_option).c_str()); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorWithBlockSizeOption) { + // Default block_size (0 -> 32) + RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 0); + // Explicit block_size=16 + RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 16); +} + +// UINT8 per-tensor DQ -> MatMul -> MatMulNBits +// Tests shapes from real models including small dimensions (N=1, N=8). +template +void RunDQMatMulPerTensorUint8Converted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, uint8_t(0), uint8_t(255)); + auto* dq_output = builder.MakeIntermediate(); + + // Scalar scale (per-tensor) + auto* scale_arg = builder.MakeInitializer({}, {0.05f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, {uint8_t(128)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 0.01 /*per_sample_tolerance - higher due to blockwise accumulation reordering*/, + 5e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorUint8ConvertedToMatMulNBits) { + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); +} + +// --------------------------------------------------------------------------- +// DQ -> Gemm tests for MatMulNBits fusion +// --------------------------------------------------------------------------- + +// Input1 +// | DQ (4-bit weight) +// \ / +// Gemm +// | +// output +// Gemm has no bias, equivalent to MatMul. Should fuse to MatMulNBits. +template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedNoBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_NoBias) { + RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ (4-bit weight) bias (float) +// \ / / +// Gemm +// | +// output +// Gemm has a direct (non-DQ) float bias. Should fuse to MatMulNBits with bias at input 5. +template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedWithBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + int64_t N = weight_shape[1]; + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + builder.AddNode("Gemm", {input_arg, dq_output, bias_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithBias) { + RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ (4-bit weight) DQ (bias) +// \ / / +// Gemm +// | +// output +// Gemm has a bias from DQ. Weight DQ fused into MatMulNBits, bias DQ stays alive, +// bias DQ output wired to MatMulNBits input 5. +template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedWithDQBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // Weight DQ + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + int64_t N = weight_shape[1]; + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + // Bias DQ (int8 quantized bias -> float) + auto* bias_quantized = builder.MakeInitializer({N}, std::vector(static_cast(N), 5)); + auto* bias_scale = builder.MakeInitializer({}, std::vector{0.1f}); + auto* bias_zp = builder.MakeInitializer({}, std::vector{0}); + auto* bias_dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {bias_quantized, bias_scale, bias_zp}, {bias_dq_output}); + + builder.AddNode("Gemm", {input_arg, dq_output, bias_dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + // Weight DQ removed, bias DQ stays + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithDQBias) { + RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Negative test: DQ -> Gemm with transB=1 should NOT be fused. +TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_TransB) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({12, 37}, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // With transB=1, Gemm transposes B at runtime: weight shape [N,K]=[12,37], transposed to [K,N]=[37,12]. + // DQ weight shape is [12,37] (N=12, K=37 after transpose). + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", static_cast(16)), dq_attrs); + auto* weight_arg = builder.MakeInitializer({12, 37}, Int4x2(Int4x2::min_val, 0), Int4x2(Int4x2::max_val, 0)); + auto* scales_arg = builder.MakeInitializer({1, 37}, 8.0f, 12.0f); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + + NodeAttributes gemm_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("transB", static_cast(1)), gemm_attrs); + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}, "", &gemm_attrs); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5, 2e-5); +} + +// Negative test: DQ -> Gemm with alpha != 1.0 should NOT be fused. +TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_Alpha) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({12, 37}, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", static_cast(16)), dq_attrs); + auto* weight_arg = builder.MakeInitializer({37, 12}, Int4x2(Int4x2::min_val, 0), Int4x2(Int4x2::max_val, 0)); + auto* scales_arg = builder.MakeInitializer({3, 12}, 8.0f, 12.0f); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + + NodeAttributes gemm_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("alpha", 2.0f), gemm_attrs); + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}, "", &gemm_attrs); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5, 2e-5); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test From 883b461c8ae35cd69105dbd44e93fa6f98e0f7b2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:09:24 +0000 Subject: [PATCH 10/17] Bump rollup from 4.35.0 to 4.59.0 in /js/web/test/e2e/exports/testcases/vite-default (#27463) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [rollup](https://github.com/rollup/rollup) from 4.35.0 to 4.59.0.
Release notes

Sourced from rollup's releases.

v4.59.0

4.59.0

2026-02-22

Features

  • Throw when the generated bundle contains paths that would leave the output directory (#6276)

Pull Requests

v4.58.0

4.58.0

2026-02-20

Features

  • Also support __NO_SIDE_EFFECTS__ annotation before variable declarations declaring function expressions (#6272)

Pull Requests

v4.57.1

4.57.1

2026-01-30

Bug Fixes

  • Fix heap corruption issue in Windows (#6251)
  • Ensure exports of a dynamic import are fully included when called from a try...catch (#6254)

Pull Requests

... (truncated)

Changelog

Sourced from rollup's changelog.

4.59.0

2026-02-22

Features

  • Throw when the generated bundle contains paths that would leave the output directory (#6276)

Pull Requests

4.58.0

2026-02-20

Features

  • Also support __NO_SIDE_EFFECTS__ annotation before variable declarations declaring function expressions (#6272)

Pull Requests

4.57.1

2026-01-30

Bug Fixes

  • Fix heap corruption issue in Windows (#6251)
  • Ensure exports of a dynamic import are fully included when called from a try...catch (#6254)

Pull Requests

... (truncated)

Commits
Maintainer changes

This version was pushed to npm by [GitHub Actions](https://www.npmjs.com/~GitHub Actions), a new releaser for rollup since your current version.

Install script changes

This version modifies prepare script that runs during installation. Review the package contents before updating.


[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=rollup&package-manager=npm_and_yarn&previous-version=4.35.0&new-version=4.59.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../testcases/vite-default/package-lock.json | 260 ++++++++++++------ 1 file changed, 175 insertions(+), 85 deletions(-) diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index 62b4df5806eda..ed0559c85ee1b 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -493,9 +493,9 @@ "license": "MIT" }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.35.0.tgz", - "integrity": "sha512-uYQ2WfPaqz5QtVgMxfN6NpLD+no0MYHDBywl7itPYd3K5TjjSghNKmX8ic9S8NU8w81NVhJv/XojcHptRly7qQ==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.59.0.tgz", + "integrity": "sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==", "cpu": [ "arm" ], @@ -507,9 +507,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.35.0.tgz", - "integrity": "sha512-FtKddj9XZudurLhdJnBl9fl6BwCJ3ky8riCXjEw3/UIbjmIY58ppWwPEvU3fNu+W7FUsAsB1CdH+7EQE6CXAPA==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.59.0.tgz", + "integrity": "sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==", "cpu": [ "arm64" ], @@ -521,9 +521,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.35.0.tgz", - "integrity": "sha512-Uk+GjOJR6CY844/q6r5DR/6lkPFOw0hjfOIzVx22THJXMxktXG6CbejseJFznU8vHcEBLpiXKY3/6xc+cBm65Q==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.59.0.tgz", + "integrity": "sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==", "cpu": [ "arm64" ], @@ -535,9 +535,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.35.0.tgz", - "integrity": "sha512-3IrHjfAS6Vkp+5bISNQnPogRAW5GAV1n+bNCrDwXmfMHbPl5EhTmWtfmwlJxFRUCBZ+tZ/OxDyU08aF6NI/N5Q==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.59.0.tgz", + "integrity": "sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==", "cpu": [ "x64" ], @@ -549,9 +549,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.35.0.tgz", - "integrity": "sha512-sxjoD/6F9cDLSELuLNnY0fOrM9WA0KrM0vWm57XhrIMf5FGiN8D0l7fn+bpUeBSU7dCgPV2oX4zHAsAXyHFGcQ==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.59.0.tgz", + "integrity": "sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==", "cpu": [ "arm64" ], @@ -563,9 +563,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.35.0.tgz", - "integrity": "sha512-2mpHCeRuD1u/2kruUiHSsnjWtHjqVbzhBkNVQ1aVD63CcexKVcQGwJ2g5VphOd84GvxfSvnnlEyBtQCE5hxVVw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.59.0.tgz", + "integrity": "sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==", "cpu": [ "x64" ], @@ -577,9 +577,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.35.0.tgz", - "integrity": "sha512-mrA0v3QMy6ZSvEuLs0dMxcO2LnaCONs1Z73GUDBHWbY8tFFocM6yl7YyMu7rz4zS81NDSqhrUuolyZXGi8TEqg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.59.0.tgz", + "integrity": "sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==", "cpu": [ "arm" ], @@ -591,9 +591,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.35.0.tgz", - "integrity": "sha512-DnYhhzcvTAKNexIql8pFajr0PiDGrIsBYPRvCKlA5ixSS3uwo/CWNZxB09jhIapEIg945KOzcYEAGGSmTSpk7A==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.59.0.tgz", + "integrity": "sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==", "cpu": [ "arm" ], @@ -605,9 +605,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.35.0.tgz", - "integrity": "sha512-uagpnH2M2g2b5iLsCTZ35CL1FgyuzzJQ8L9VtlJ+FckBXroTwNOaD0z0/UF+k5K3aNQjbm8LIVpxykUOQt1m/A==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.59.0.tgz", + "integrity": "sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==", "cpu": [ "arm64" ], @@ -619,9 +619,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.35.0.tgz", - "integrity": "sha512-XQxVOCd6VJeHQA/7YcqyV0/88N6ysSVzRjJ9I9UA/xXpEsjvAgDTgH3wQYz5bmr7SPtVK2TsP2fQ2N9L4ukoUg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.59.0.tgz", + "integrity": "sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==", "cpu": [ "arm64" ], @@ -632,10 +632,10 @@ "linux" ] }, - "node_modules/@rollup/rollup-linux-loongarch64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loongarch64-gnu/-/rollup-linux-loongarch64-gnu-4.35.0.tgz", - "integrity": "sha512-5pMT5PzfgwcXEwOaSrqVsz/LvjDZt+vQ8RT/70yhPU06PTuq8WaHhfT1LW+cdD7mW6i/J5/XIkX/1tCAkh1W6g==", + "node_modules/@rollup/rollup-linux-loong64-gnu": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.59.0.tgz", + "integrity": "sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==", "cpu": [ "loong64" ], @@ -646,10 +646,38 @@ "linux" ] }, - "node_modules/@rollup/rollup-linux-powerpc64le-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-powerpc64le-gnu/-/rollup-linux-powerpc64le-gnu-4.35.0.tgz", - "integrity": "sha512-c+zkcvbhbXF98f4CtEIP1EBA/lCic5xB0lToneZYvMeKu5Kamq3O8gqrxiYYLzlZH6E3Aq+TSW86E4ay8iD8EA==", + "node_modules/@rollup/rollup-linux-loong64-musl": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.59.0.tgz", + "integrity": "sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-gnu": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.59.0.tgz", + "integrity": "sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-musl": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.59.0.tgz", + "integrity": "sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==", "cpu": [ "ppc64" ], @@ -661,9 +689,23 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.35.0.tgz", - "integrity": "sha512-s91fuAHdOwH/Tad2tzTtPX7UZyytHIRR6V4+2IGlV0Cej5rkG0R61SX4l4y9sh0JBibMiploZx3oHKPnQBKe4g==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.59.0.tgz", + "integrity": "sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-musl": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.59.0.tgz", + "integrity": "sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==", "cpu": [ "riscv64" ], @@ -675,9 +717,9 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.35.0.tgz", - "integrity": "sha512-hQRkPQPLYJZYGP+Hj4fR9dDBMIM7zrzJDWFEMPdTnTy95Ljnv0/4w/ixFw3pTBMEuuEuoqtBINYND4M7ujcuQw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.59.0.tgz", + "integrity": "sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==", "cpu": [ "s390x" ], @@ -689,9 +731,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.35.0.tgz", - "integrity": "sha512-Pim1T8rXOri+0HmV4CdKSGrqcBWX0d1HoPnQ0uw0bdp1aP5SdQVNBy8LjYncvnLgu3fnnCt17xjWGd4cqh8/hA==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.59.0.tgz", + "integrity": "sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==", "cpu": [ "x64" ], @@ -703,9 +745,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.35.0.tgz", - "integrity": "sha512-QysqXzYiDvQWfUiTm8XmJNO2zm9yC9P/2Gkrwg2dH9cxotQzunBHYr6jk4SujCTqnfGxduOmQcI7c2ryuW8XVg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.59.0.tgz", + "integrity": "sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==", "cpu": [ "x64" ], @@ -716,10 +758,38 @@ "linux" ] }, + "node_modules/@rollup/rollup-openbsd-x64": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.59.0.tgz", + "integrity": "sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ] + }, + "node_modules/@rollup/rollup-openharmony-arm64": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.59.0.tgz", + "integrity": "sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ] + }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.35.0.tgz", - "integrity": "sha512-OUOlGqPkVJCdJETKOCEf1mw848ZyJ5w50/rZ/3IBQVdLfR5jk/6Sr5m3iO2tdPgwo0x7VcncYuOvMhBWZq8ayg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.59.0.tgz", + "integrity": "sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==", "cpu": [ "arm64" ], @@ -731,9 +801,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.35.0.tgz", - "integrity": "sha512-2/lsgejMrtwQe44glq7AFFHLfJBPafpsTa6JvP2NGef/ifOa4KBoglVf7AKN7EV9o32evBPRqfg96fEHzWo5kw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.59.0.tgz", + "integrity": "sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==", "cpu": [ "ia32" ], @@ -744,10 +814,24 @@ "win32" ] }, + "node_modules/@rollup/rollup-win32-x64-gnu": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.59.0.tgz", + "integrity": "sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.35.0.tgz", - "integrity": "sha512-PIQeY5XDkrOysbQblSW7v3l1MDZzkTEzAfTPkj5VAu3FW8fS4ynyLg2sINp0fp3SjZ8xkRYpLqoKcYqAkhU1dw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.59.0.tgz", + "integrity": "sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==", "cpu": [ "x64" ], @@ -759,9 +843,9 @@ ] }, "node_modules/@types/estree": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.6.tgz", - "integrity": "sha512-AYnb1nQyY49te+VRAVgmzfcgjYS91mY5P0TKUDCLEM+gNnA+3T6rWITXRLYCpahpqSQbN5cE+gHpnPyXjHWxcw==", + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "dev": true, "license": "MIT" }, @@ -1049,13 +1133,13 @@ } }, "node_modules/rollup": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.35.0.tgz", - "integrity": "sha512-kg6oI4g+vc41vePJyO6dHt/yl0Rz3Thv0kJeVQ3D1kS3E5XSuKbPc29G4IpT/Kv1KQwgHVcN+HtyS+HYLNSvQg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.59.0.tgz", + "integrity": "sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==", "dev": true, "license": "MIT", "dependencies": { - "@types/estree": "1.0.6" + "@types/estree": "1.0.8" }, "bin": { "rollup": "dist/bin/rollup" @@ -1065,25 +1149,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.35.0", - "@rollup/rollup-android-arm64": "4.35.0", - "@rollup/rollup-darwin-arm64": "4.35.0", - "@rollup/rollup-darwin-x64": "4.35.0", - "@rollup/rollup-freebsd-arm64": "4.35.0", - "@rollup/rollup-freebsd-x64": "4.35.0", - "@rollup/rollup-linux-arm-gnueabihf": "4.35.0", - "@rollup/rollup-linux-arm-musleabihf": "4.35.0", - "@rollup/rollup-linux-arm64-gnu": "4.35.0", - "@rollup/rollup-linux-arm64-musl": "4.35.0", - "@rollup/rollup-linux-loongarch64-gnu": "4.35.0", - "@rollup/rollup-linux-powerpc64le-gnu": "4.35.0", - "@rollup/rollup-linux-riscv64-gnu": "4.35.0", - "@rollup/rollup-linux-s390x-gnu": "4.35.0", - "@rollup/rollup-linux-x64-gnu": "4.35.0", - "@rollup/rollup-linux-x64-musl": "4.35.0", - "@rollup/rollup-win32-arm64-msvc": "4.35.0", - "@rollup/rollup-win32-ia32-msvc": "4.35.0", - "@rollup/rollup-win32-x64-msvc": "4.35.0", + "@rollup/rollup-android-arm-eabi": "4.59.0", + "@rollup/rollup-android-arm64": "4.59.0", + "@rollup/rollup-darwin-arm64": "4.59.0", + "@rollup/rollup-darwin-x64": "4.59.0", + "@rollup/rollup-freebsd-arm64": "4.59.0", + "@rollup/rollup-freebsd-x64": "4.59.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.59.0", + "@rollup/rollup-linux-arm-musleabihf": "4.59.0", + "@rollup/rollup-linux-arm64-gnu": "4.59.0", + "@rollup/rollup-linux-arm64-musl": "4.59.0", + "@rollup/rollup-linux-loong64-gnu": "4.59.0", + "@rollup/rollup-linux-loong64-musl": "4.59.0", + "@rollup/rollup-linux-ppc64-gnu": "4.59.0", + "@rollup/rollup-linux-ppc64-musl": "4.59.0", + "@rollup/rollup-linux-riscv64-gnu": "4.59.0", + "@rollup/rollup-linux-riscv64-musl": "4.59.0", + "@rollup/rollup-linux-s390x-gnu": "4.59.0", + "@rollup/rollup-linux-x64-gnu": "4.59.0", + "@rollup/rollup-linux-x64-musl": "4.59.0", + "@rollup/rollup-openbsd-x64": "4.59.0", + "@rollup/rollup-openharmony-arm64": "4.59.0", + "@rollup/rollup-win32-arm64-msvc": "4.59.0", + "@rollup/rollup-win32-ia32-msvc": "4.59.0", + "@rollup/rollup-win32-x64-gnu": "4.59.0", + "@rollup/rollup-win32-x64-msvc": "4.59.0", "fsevents": "~2.3.2" } }, From 793153bc184141a389b15920c500360bd75a3c3f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:45:38 +0000 Subject: [PATCH 11/17] Bump flatted from 3.3.3 to 3.4.2 in /js/react_native/e2e (#27785) Bumps [flatted](https://github.com/WebReflection/flatted) from 3.3.3 to 3.4.2.
Commits
  • 3bf0909 3.4.2
  • 885ddcc fix CWE-1321
  • 0bdba70 added flatted-view to the benchmark
  • 2a02dce 3.4.1
  • fba4e8f Merge pull request #89 from WebReflection/python-fix
  • 5fe8648 added "when in Rome" also a test for PHP
  • 53517ad some minor improvement
  • b3e2a0c Fixing recursion issue in Python too
  • c4b46db Add SECURITY.md for security policy and reporting
  • f86d071 Create dependabot.yml for version updates
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=flatted&package-manager=npm_and_yarn&previous-version=3.3.3&new-version=3.4.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index d8a273ef6825f..3f9ff05a72f97 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -7149,9 +7149,9 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true, "license": "ISC" }, From faad20f9d3264c7f3b6d4e4398990e13ee864512 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:47:26 +0000 Subject: [PATCH 12/17] Bump flatted from 3.3.3 to 3.4.2 in /js (#27799) Bumps [flatted](https://github.com/WebReflection/flatted) from 3.3.3 to 3.4.2.
Commits
  • 3bf0909 3.4.2
  • 885ddcc fix CWE-1321
  • 0bdba70 added flatted-view to the benchmark
  • 2a02dce 3.4.1
  • fba4e8f Merge pull request #89 from WebReflection/python-fix
  • 5fe8648 added "when in Rome" also a test for PHP
  • 53517ad some minor improvement
  • b3e2a0c Fixing recursion issue in Python too
  • c4b46db Add SECURITY.md for security policy and reporting
  • f86d071 Create dependabot.yml for version updates
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=flatted&package-manager=npm_and_yarn&previous-version=3.3.3&new-version=3.4.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/package-lock.json | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/js/package-lock.json b/js/package-lock.json index 22fb22757e94b..1ba8fc900bbd8 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,7 +4,6 @@ "requires": true, "packages": { "": { - "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3014,11 +3013,10 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", - "dev": true, - "license": "ISC" + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", + "dev": true }, "node_modules/for-each": { "version": "0.3.5", @@ -7946,9 +7944,9 @@ } }, "flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true }, "for-each": { From 38a2625e365e2ab149d5912141c78f974f466576 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 24 Mar 2026 10:27:22 -0700 Subject: [PATCH 13/17] [MLAS] Add fused Silu and Gelu kernels for AVX512 (#27690) ### Description Add fused Silu and Exact Gelu (Erf based) for AVX512f Silu benchmarks: image GELU exact (Erf) benchmarks: image ### Motivation and Context Improve performance on AVX512F Silu shows small regression at B=1 but I don't think the absolute difference is much --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- cmake/onnxruntime_mlas.cmake | 14 +- onnxruntime/contrib_ops/cpu/activations.h | 28 +- onnxruntime/core/mlas/inc/mlas.h | 24 ++ onnxruntime/core/mlas/lib/gelu.cpp | 65 ++++ .../lib/intrinsics/avx512/gelu_avx512f.cpp | 219 ++++++++++++++ .../lib/intrinsics/avx512/silu_avx512f.cpp | 140 +++++++++ onnxruntime/core/mlas/lib/mlasi.h | 6 + onnxruntime/core/mlas/lib/platform.cpp | 5 +- onnxruntime/core/mlas/lib/silu.cpp | 51 ++++ onnxruntime/core/providers/cpu/tensor/gelu.cc | 13 +- onnxruntime/core/providers/cpu/tensor/gelu.h | 2 + .../test/mlas/bench/bench_transcendental.cpp | 189 ++++++++++++ .../unittest/test_transcendental_avx512.cpp | 285 ++++++++++++++++++ 13 files changed, 1015 insertions(+), 26 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/gelu.cpp create mode 100644 onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp create mode 100644 onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp create mode 100644 onnxruntime/core/mlas/lib/silu.cpp create mode 100644 onnxruntime/test/mlas/bench/bench_transcendental.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 0156e46b86bc4..4f75a8b105ec2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -34,6 +34,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.h ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/silu.cpp + ${MLAS_SRC_DIR}/gelu.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp @@ -201,6 +203,14 @@ function(setup_mlas_source_for_windows) ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set(mlas_platform_srcs_avx512 + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + + set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512") + target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/dgemm.cpp ${mlas_platform_srcs_avx} @@ -212,7 +222,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${mlas_platform_srcs_avx512} ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp @@ -764,6 +774,8 @@ endif() ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index f00fad809968f..71e0e8561e110 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -78,22 +78,22 @@ class QuickGelu : public OpKernel { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - if (alpha_ != 1.0f) { - // TODO: Consider vectorizing this scalar multiplication. - // It needs exposing a new API in MLAS to take in a scalar - // that will be used in the elementwise multiplication. - // Estimate the cost-benefit tradeoff before proceeding - // with that optimization. - for (int64_t i = 0; i < count; i++) { - p_output[i] = p_input[i] * alpha_; - } - - MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); - } else { - // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f - MlasComputeLogistic(p_input, p_output, onnxruntime::narrow(count)); + if (alpha_ == 1.0f) { + MlasComputeSilu(p_input, p_output, onnxruntime::narrow(count)); + return; } + // TODO: Consider vectorizing this scalar multiplication. + // It needs exposing a new API in MLAS to take in a scalar + // that will be used in the elementwise multiplication. + // Estimate the cost-benefit tradeoff before proceeding + // with that optimization. + for (int64_t i = 0; i < count; i++) { + p_output[i] = p_input[i] * alpha_; + } + + MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); + MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow(count)); }, 0); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 56849995656f3..2b446c4b2601b 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1113,6 +1113,30 @@ MlasComputeErf( size_t N ); +// +// Note: The Input and Output buffers for MlasComputeGeluErf must not overlap. +// In-place operation (e.g., passing the same buffer for both parameters) is unsupported. +// +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N + ); + +// +// Note: The Input and Output buffers for MlasComputeSilu must not overlap. +// In-place operation (e.g., passing the same buffer for both parameters) is unsupported. +// +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ); + template void MLASCALL diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp new file mode 100644 index 0000000000000..dc25611652c77 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -0,0 +1,65 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu.cpp + +Abstract: + + This module implements routines to compute the exact Gelu function. + +--*/ + +#include "mlasi.h" + +namespace { + +constexpr float kInvSqrt2 = 0.70710678118654752440f; + +} // namespace + + +void +MLASCALL +MlasGeluErfKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe because it is implemented in + // multiple passes: first scale Input into Output, then apply erf in place, + // and finally combine that intermediate with the original Input values. + // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). + for (size_t i = 0; i < N; ++i) { + Output[i] = Input[i] * kInvSqrt2; + } + + MlasComputeErf(Output, Output, N); + + for (size_t i = 0; i < N; ++i) { + Output[i] = 0.5f * Input[i] * (Output[i] + 1.0f); + } +} + +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 GELU(erf) path on AMD64. + // Today the dispatch jumps from the generic multi-pass implementation to + // AVX512F, so non-AVX512 x64 machines fall back to the generic kernel. + GetMlasPlatform().GeluErfKernelRoutine(Input, Output, N); +#else + MlasGeluErfKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp new file mode 100644 index 0000000000000..4a9f3a100ed65 --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -0,0 +1,219 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu_avx512f.cpp + +Abstract: + + This module implements routines to compute exact Gelu with AVX512F + intrinsics. + +--*/ + +#include + +#include "mlasi.h" + +namespace { + +struct GeluAvx512Constants { + static constexpr int32_t SignBitMask = INT32_MIN; + static constexpr float InvSqrt2 = 0.70710678118654752440f; + static constexpr float Half = 0.5f; + static constexpr float One = 1.0f; + + static constexpr float ErfUpperAbsRange = 3.925f; + static constexpr float ErfSplitBoundary = 0.921875f; + static constexpr float ErfSMALL_P0 = -5.99104969e-4f; + static constexpr float ErfSMALL_P1 = 4.99339588e-3f; + static constexpr float ErfSMALL_P2 = -2.67667342e-2f; + static constexpr float ErfSMALL_P3 = 1.12818025e-1f; + static constexpr float ErfSMALL_P4 = -3.76124859e-1f; + static constexpr float ErfSMALL_P5_Minus_One = 1.28379151e-1f; + static constexpr float ErfBIG_P0 = 1.72948930e-5f; + static constexpr float ErfBIG_P1 = -3.83208680e-4f; + static constexpr float ErfBIG_P2 = 3.88393435e-3f; + static constexpr float ErfBIG_P3 = -2.42545605e-2f; + static constexpr float ErfBIG_P4 = 1.06777847e-1f; + static constexpr float ErfBIG_P5 = 6.34846687e-1f; + static constexpr float ErfBIG_P6_Minus_One = 1.28717512e-1f; + static constexpr float ErfOne = 1.0f; + static constexpr float ExpLowerRange = -88.3762626647949f; + static constexpr float ExpLog2Reciprocal = 1.44269504088896341f; + static constexpr float ExpLog2Hi = -6.93145752e-1f; + static constexpr float ExpLog2Lo = -1.42860677e-6f; + static constexpr float ExpP0 = 1.38319808e-3f; + static constexpr float ExpP1 = 8.37550033e-3f; + static constexpr float ExpP2 = 4.16689515e-2f; + static constexpr float ExpP3 = 1.66664466e-1f; + static constexpr float ExpP4 = 4.99999851e-1f; + static constexpr float ExpP5 = 1.0f; + static constexpr float ExpP6 = 1.0f; + static constexpr float ExpC = 1.25829120e+7f; +}; + +struct GeluAvx512BroadcastConstants { + const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask)); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 InvSqrt2 = _mm512_set1_ps(GeluAvx512Constants::InvSqrt2); + const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half); + const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One); + const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); + const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); + const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0); + const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1); + const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2); + const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3); + const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4); + const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One); + const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0); + const __m512 ErfBigP1 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1); + const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2); + const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3); + const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4); + const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5); + const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One); + const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); + const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); + const __m512 ExpLog2Reciprocal = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Reciprocal); + const __m512 ExpLog2Hi = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Hi); + const __m512 ExpLog2Lo = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Lo); + const __m512 ExpP0 = _mm512_set1_ps(GeluAvx512Constants::ExpP0); + const __m512 ExpP1 = _mm512_set1_ps(GeluAvx512Constants::ExpP1); + const __m512 ExpP2 = _mm512_set1_ps(GeluAvx512Constants::ExpP2); + const __m512 ExpP3 = _mm512_set1_ps(GeluAvx512Constants::ExpP3); + const __m512 ExpP4 = _mm512_set1_ps(GeluAvx512Constants::ExpP4); + const __m512 ExpP5 = _mm512_set1_ps(GeluAvx512Constants::ExpP5); + const __m512 ExpP6 = _mm512_set1_ps(GeluAvx512Constants::ExpP6); + const __m512 ExpC = _mm512_set1_ps(GeluAvx512Constants::ExpC); +}; + +MLAS_FORCEINLINE __m512 +MlasGeluErfExpVectorAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + __m512 R = _mm512_fmadd_ps(Constants.ExpLog2Reciprocal, Value, Constants.ExpC); + R = _mm512_sub_ps(R, Constants.ExpC); + + __m512 Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Hi, Value); + Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Lo, Fx); + + __m512 Y = Constants.ExpP0; + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP1); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP2); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP3); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP4); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP5); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP6); + Y = _mm512_scalef_ps(Y, R); + + return Y; +} + +MLAS_FORCEINLINE __m512 +MlasGeluErfAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(Constants.NegZero))); + __m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(Constants.NegZero), _mm512_castps_si512(Value))); + AbsValue = _mm512_min_ps(Constants.ErfUpperAbsRange, AbsValue); + + const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue); + + __m512 SmallResult = Constants.ErfSmallP0; + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP1); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP2); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP3); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP4); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP5MinusOne); + SmallResult = _mm512_fmadd_ps(SmallResult, AbsValue, AbsValue); + + const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, Constants.ErfSplitBoundary, _CMP_GT_OQ); + const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Constants.Zero, AbsValue); + + __m512 BigResult = Constants.ErfBigP0; + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP1); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP2); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP3); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP4); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP5); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP6MinusOne); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput); + + BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(Constants.NegZero))); + BigResult = _mm512_max_ps(Constants.ExpLowerRange, BigResult); + BigResult = _mm512_sub_ps(Constants.ErfOne, MlasGeluErfExpVectorAvx512(BigResult, Constants)); + + __m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult); + Result = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(Result), _mm512_castps_si512(SignMask))); + return Result; +} + +MLAS_FORCEINLINE __m512 +MlasComputeGeluVectorExactAvx512( + __m512 X, + const GeluAvx512BroadcastConstants& Constants + ) +{ + const __m512 ErfInput = _mm512_mul_ps(X, Constants.InvSqrt2); + const __m512 ErfValue = MlasGeluErfAvx512(ErfInput, Constants); + __m512 Result = _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One)); + + // Preserve NaN payload/sign behavior explicitly because the erf + // approximation uses min/max style range limiting that is not guaranteed to + // preserve NaNs the same way as the existing MLAS GELU semantics. + const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + return Result; +} + +void +MlasGeluErfKernelAvx512FExactImpl( + const float* Input, + float* Output, + size_t N + ) +{ + const GeluAvx512BroadcastConstants Constants; + while (N >= 16) { + const __m512 X = _mm512_loadu_ps(Input); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants); + + _mm512_storeu_ps(Output, Result); + + Input += 16; + Output += 16; + N -= 16; + } + + if (N > 0) { + const __mmask16 TailMask = __mmask16((1u << static_cast(N)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants); + + _mm512_mask_storeu_ps(Output, TailMask, Result); + } +} + +} // namespace + +void +MLASCALL +MlasGeluErfKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + MlasGeluErfKernelAvx512FExactImpl(Input, Output, N); +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp new file mode 100644 index 0000000000000..7e8424d94827a --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -0,0 +1,140 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + silu_avx512f.cpp + +Abstract: + + This module implements routines to compute the SiLU function with AVX512F + intrinsics. + +--*/ + +#include "mlasi.h" + +namespace { + +struct SiluAvx512Constants { + static constexpr float LogisticLowerRange = -18.0f; + static constexpr float LogisticUpperRange = 18.0f; + static constexpr float Alpha9 = 4.37031012579801e-11f; + static constexpr float Alpha7 = 1.15627324459942e-07f; + static constexpr float Alpha5 = 6.08574864600143e-05f; + static constexpr float Alpha3 = 8.51377133304701e-03f; + static constexpr float Alpha1 = 2.48287947061529e-01f; + static constexpr float Beta10 = 6.10247389755681e-13f; + static constexpr float Beta8 = 5.76102136993427e-09f; + static constexpr float Beta6 = 6.29106785017040e-06f; + static constexpr float Beta4 = 1.70198817374094e-03f; + static constexpr float Beta2 = 1.16817656904453e-01f; + static constexpr float Beta0 = 9.93151921023180e-01f; + static constexpr float OneHalf = 0.5f; +}; + +struct SiluAvx512BroadcastConstants { + const __m512 LogisticLowerRange = _mm512_set1_ps(SiluAvx512Constants::LogisticLowerRange); + const __m512 LogisticUpperRange = _mm512_set1_ps(SiluAvx512Constants::LogisticUpperRange); + const __m512 Alpha9 = _mm512_set1_ps(SiluAvx512Constants::Alpha9); + const __m512 Alpha7 = _mm512_set1_ps(SiluAvx512Constants::Alpha7); + const __m512 Alpha5 = _mm512_set1_ps(SiluAvx512Constants::Alpha5); + const __m512 Alpha3 = _mm512_set1_ps(SiluAvx512Constants::Alpha3); + const __m512 Alpha1 = _mm512_set1_ps(SiluAvx512Constants::Alpha1); + const __m512 Beta10 = _mm512_set1_ps(SiluAvx512Constants::Beta10); + const __m512 Beta8 = _mm512_set1_ps(SiluAvx512Constants::Beta8); + const __m512 Beta6 = _mm512_set1_ps(SiluAvx512Constants::Beta6); + const __m512 Beta4 = _mm512_set1_ps(SiluAvx512Constants::Beta4); + const __m512 Beta2 = _mm512_set1_ps(SiluAvx512Constants::Beta2); + const __m512 Beta0 = _mm512_set1_ps(SiluAvx512Constants::Beta0); + const __m512 OneHalf = _mm512_set1_ps(SiluAvx512Constants::OneHalf); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 One = _mm512_set1_ps(1.0f); +}; + +MLAS_FORCEINLINE __m512 +MlasLogisticApproxAvx512( + __m512 Value, + const SiluAvx512BroadcastConstants& Constants + ) +{ + // Mirror MlasComputeLogistic by evaluating the same clamped rational + // approximation in-register and then multiplying by x for SiLU. + const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, Constants.LogisticUpperRange), Constants.LogisticLowerRange); + const __m512 ValueSquared = _mm512_mul_ps(ClampedValue, ClampedValue); + + __m512 P = _mm512_fmadd_ps(ValueSquared, Constants.Alpha9, Constants.Alpha7); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha5); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha3); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha1); + P = _mm512_mul_ps(P, ClampedValue); + + __m512 Q = _mm512_fmadd_ps(ValueSquared, Constants.Beta10, Constants.Beta8); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta6); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta4); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta2); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta0); + + __m512 Logistic = _mm512_add_ps(_mm512_div_ps(P, Q), Constants.OneHalf); + Logistic = _mm512_min_ps(_mm512_max_ps(Logistic, Constants.Zero), Constants.One); + + return Logistic; +} + +MLAS_FORCEINLINE __m512 +MlasComputeSiluVectorAvx512( + __m512 X, + const SiluAvx512BroadcastConstants& Constants + ) +{ + __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X, Constants)); + + // Preserve NaN payload/sign behavior explicitly because the clamped + // logistic approximation uses min/max operations that do not reliably + // propagate NaNs the same way as the existing MLAS SiLU semantics. + const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + return Result; +} + +} // namespace + +void +MLASCALL +MlasSiluKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + const SiluAvx512BroadcastConstants Constants; + size_t Offset = 0; + + while (Offset + 32 <= N) { + const __m512 X0 = _mm512_loadu_ps(Input + Offset); + const __m512 X1 = _mm512_loadu_ps(Input + Offset + 16); + const __m512 Result0 = MlasComputeSiluVectorAvx512(X0, Constants); + const __m512 Result1 = MlasComputeSiluVectorAvx512(X1, Constants); + _mm512_storeu_ps(Output + Offset, Result0); + _mm512_storeu_ps(Output + Offset + 16, Result1); + Offset += 32; + } + + while (Offset + 16 <= N) { + const __m512 X = _mm512_loadu_ps(Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); + _mm512_storeu_ps(Output + Offset, Result); + Offset += 16; + } + + if (Offset < N) { + const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); + _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 954849fe90049..0dab8e41f25cd 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1096,6 +1096,8 @@ extern "C" { #endif MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasTanhKernel; @@ -1126,6 +1128,8 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelAvx512F; #endif MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; @@ -1477,6 +1481,8 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif #if defined(MLAS_TARGET_AMD64) + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ac3761d63bd20..eccde79848e61 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -283,7 +283,9 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; + this->GeluErfKernelRoutine = MlasGeluErfKernel; this->LogisticKernelRoutine = MlasLogisticKernel; + this->SiluKernelRoutine = MlasSiluKernel; this->TanhKernelRoutine = MlasTanhKernel; this->ErfKernelRoutine = MlasErfKernel; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; @@ -459,7 +461,8 @@ Return Value: // if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { - + this->GeluErfKernelRoutine = MlasGeluErfKernelAvx512F; + this->SiluKernelRoutine = MlasSiluKernelAvx512F; this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F; diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp new file mode 100644 index 0000000000000..96686e4bdf1da --- /dev/null +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -0,0 +1,51 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + silu.cpp + +Abstract: + + This module implements routines to compute the SiLU function. + +--*/ + +#include "mlasi.h" + +void +MLASCALL +MlasSiluKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe because it is implemented in two + // passes: first compute logistic(Input) into Output, then multiply that + // intermediate by the original Input values. Callers must guarantee that + // Input and Output do not overlap (see mlas.h for aliasing requirements). + MlasComputeLogistic(Input, Output, N); + MlasEltwiseMul(Input, Output, Output, N); +} + +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 SiLU path on AMD64. Today the + // dispatch jumps from the generic two-pass implementation to AVX512F, so + // non-AVX512 x64 machines fall back to the generic kernel. + GetMlasPlatform().SiluKernelRoutine(Input, Output, N); +#else + MlasSiluKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index d55973eda180f..e34af83d1f29e 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -88,16 +88,9 @@ Status Gelu::Compute(OpKernelContext* context) const { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } + // MlasComputeGeluErf requires distinct input/output buffers. This + // call uses disjoint slices from the input and output tensors. + MlasComputeGeluErf(p_input, p_output, narrow(count)); }, 0); return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h index 13238028d878a..14a070609a69b 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.h +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once + namespace onnxruntime { template diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp new file mode 100644 index 0000000000000..f7e461c29843a --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include + +#include "mlas.h" +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +namespace { + +// Compare fused MLAS unary activation paths against unfused baselines for +// SiLU and exact GELU(erf). + +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kInvSqrt2 = 0.7071067811865475244f; +constexpr int64_t kFusedBytesPerElement = 2; +constexpr int64_t kSiluUnfusedBytesPerElement = 5; +constexpr int64_t kGeluUnfusedBytesPerElement = 7; + +struct DispatchedUnaryPathInfo { + int64_t bytes_per_element; + const char* label; +}; + +DispatchedUnaryPathInfo GetSiluDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic path, + // which materializes the logistic result before the final multiply. + return {kSiluUnfusedBytesPerElement, "generic_fallback"}; +} + +DispatchedUnaryPathInfo GetGeluErfDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic exact + // GELU(erf) implementation, which uses separate scale, erf, and final passes. + return {kGeluUnfusedBytesPerElement, "generic_fallback"}; +} + +std::vector MakeInput(size_t n, float min_value, float max_value) { + auto data = RandomVectorUniform(n, min_value, max_value); + + if (!data.empty()) { + data[0] = 0.0f; + } + if (data.size() > 1) { + data[1] = -0.0f; + } + if (data.size() > 2) { + data[2] = -1.0f; + } + if (data.size() > 3) { + data[3] = 1.0f; + } + + return data; +} + +template +void RunDispatchedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + DispatchedUnaryPathInfo path_info) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + state.SetLabel(path_info.label); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + const int64_t bytes_per_iteration = static_cast(n) * static_cast(sizeof(float)) * path_info.bytes_per_element; + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * bytes_per_iteration); +} + +template +void RunUnfusedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + int64_t bytes_per_element) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + const int64_t bytes_per_iteration = static_cast(n) * static_cast(sizeof(float)) * bytes_per_element; + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * bytes_per_iteration); +} + +static void UnaryKernelArgs(benchmark::internal::Benchmark* b) { + for (int n : {1, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 511, 512, 1024, 4096, 16384, 65536, 262144}) { + b->Arg(n); + } +} + +void BM_SiluDispatch(benchmark::State& state) { + // Fused MLAS SiLU entry point. On supported platforms this may dispatch to a + // specialized implementation that combines the activation into a single + // kernel instead of exposing intermediate results. + RunDispatchedUnaryBenchmark(state, MlasComputeSilu, kSiluMinValue, kSiluMaxValue, GetSiluDispatchPathInfo()); +} + +void BM_SiluUnfusedDispatch(benchmark::State& state) { + // Unfused SiLU baseline: compute logistic(x) first and then multiply by x in + // a separate elementwise pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeLogistic(input, output, n); + MlasEltwiseMul(input, output, output, n); + }, + kSiluMinValue, + kSiluMaxValue, + kSiluUnfusedBytesPerElement); +} + +void BM_GeluErfDispatchExact(benchmark::State& state) { + // Fused MLAS GELU(erf) entry point using the exact erf-based formulation. + // On AMD64 this goes through the platform dispatch layer and may select an + // architecture-specific implementation. + RunDispatchedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeGeluErf(input, output, n); + }, + kGeluMinValue, + kGeluMaxValue, + GetGeluErfDispatchPathInfo()); +} + +void BM_GeluErfUnfusedExact(benchmark::State& state) { + // Unfused exact GELU(erf) baseline: scale by 1/sqrt(2), run erf, then apply the + // final 0.5 * x * (erf(x / sqrt(2)) + 1) transform in a separate pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + for (size_t i = 0; i < n; ++i) { + output[i] = input[i] * kInvSqrt2; + } + + MlasComputeErf(output, output, n); + + for (size_t i = 0; i < n; ++i) { + output[i] = 0.5f * input[i] * (output[i] + 1.0f); + } + }, + kGeluMinValue, + kGeluMaxValue, + kGeluUnfusedBytesPerElement); +} + +} // namespace + +BENCHMARK(BM_SiluDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_SiluUnfusedDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfDispatchExact)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfUnfusedExact)->Apply(UnaryKernelArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp new file mode 100644 index 0000000000000..e87768ce3e660 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" + +#include +#include + +#if defined(MLAS_TARGET_AMD64) + +namespace { + +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; + +constexpr float kGeluAbsoluteTolerance = 2e-6f; +constexpr float kGeluRelativeTolerance = 2e-5f; +constexpr float kSiluAbsoluteTolerance = 3e-5f; +constexpr float kSiluRelativeTolerance = 5e-5f; + +constexpr std::array kShortTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255}; + +constexpr std::array kLongTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, + 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; + +bool IsGeluErfAvx512Dispatched() { + return GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F; +} + +bool IsSiluAvx512Dispatched() { + return GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F; +} + +bool UnaryOutputsMatch(float actual, float expected, float absolute_tolerance, float relative_tolerance, + bool check_signed_zero) { + if (std::isnan(expected)) { + return std::isnan(actual); + } + + if (std::isinf(expected)) { + return std::isinf(actual) && (std::signbit(actual) == std::signbit(expected)); + } + + if (check_signed_zero && actual == 0.0f && expected == 0.0f) { + return std::signbit(actual) == std::signbit(expected); + } + + const float diff = std::fabs(actual - expected); + if (diff <= absolute_tolerance) { + return true; + } + + const float scale = std::max(std::fabs(actual), std::fabs(expected)); + return scale > 0.0f && diff <= scale * relative_tolerance; +} + +const std::vector& GetGeluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 0.0f, + -0.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + }; + + return values; +} + +const std::vector& GetSiluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::numeric_limits::max(), + -std::numeric_limits::max(), + 1.0e9f, + -1.0e9f, + 0.0f, + -0.0f, + -20.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + 20.0f, + }; + + return values; +} + +void FillInput(float* input, size_t n, float minimum_value, float maximum_value, + const std::vector& special_values, uint32_t seed) { + std::mt19937 generator(seed); + std::uniform_real_distribution distribution(minimum_value, maximum_value); + + for (size_t i = 0; i < n; ++i) { + input[i] = distribution(generator); + } + + const size_t special_count = std::min(n, special_values.size()); + for (size_t i = 0; i < special_count; ++i) { + input[i] = special_values[i]; + } +} + +class MlasComputeGeluErfAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer generic_output_buffer_; + MatrixGuardBuffer public_output_buffer_; + MatrixGuardBuffer avx512_output_buffer_; + + void ExecuteCommon(const std::vector& sizes, size_t iterations) { + if (!IsGeluErfAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F GELU(erf) dispatch is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kGeluMinValue, kGeluMaxValue, GetGeluSpecialValues(), + static_cast(size * 131u + iteration * 977u + 17u)); + + MlasGeluErfKernel(input, generic_output, size); + MlasComputeGeluErf(input, public_output, size); + MlasGeluErfKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public GELU(erf) mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "GELU(erf) mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public/API GELU(erf) dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Gelu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +class MlasComputeSiluAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer generic_output_buffer_; + MatrixGuardBuffer public_output_buffer_; + MatrixGuardBuffer avx512_output_buffer_; + + void ExecuteCommon(const std::vector& sizes, size_t iterations) { + if (!IsSiluAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F SiLU dispatch is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kSiluMinValue, kSiluMaxValue, GetSiluSpecialValues(), + static_cast(size * 149u + iteration * 991u + 31u)); + + MlasSiluKernel(input, generic_output, size); + MlasComputeSilu(input, public_output, size); + MlasSiluKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public/API Silu dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Silu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +} // namespace + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } else { + count += MlasLongExecuteTests::RegisterLongExecute(); + count += MlasLongExecuteTests::RegisterLongExecute(); + } + return count; +}); + +#else + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool) { + return size_t{0}; +}); + +#endif From 99c5dd8839c504fb86c47d6e0b7ad1b56d5d6c8f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 24 Mar 2026 12:04:33 -0700 Subject: [PATCH 14/17] Make WebGPU EP compatible with EP API (#26907) ### Description This PR makes it possible to build WebGPU EP as an EP API based plugin EP. #### Requirements The goal of this PR is to support both building WebGPU EP as a bundled EP and an EP API based plugin EP. This approach allows: - enabling WebGPU EP as a standalone plugin EP package for WCR usage - graceful transition for WebGPU EP as an native EP for language binding, from the bundled EP to an EP API based plugin EP - keep the existing usage (static library) working (majorly for web) #### Design & Implementation Instead of **changing** WebGPU EP from a bundled EP to an EP API based plugin EP in one shot, this PR **extend** WebGPU EP to support building as plugin EP. - add a new folder `include/onnxruntime/ep` with a bunches of header files. Those files are not WebGPU specific. They are used for: - include common defines/functions/macros for plugin EP to use - include a few "adapter" classes that takes C-API objects to simulate ORT internal classes behaviors - include a few "override" classes that simulate ORT internal classes, but using implementations that only depend on C-API - include a special base class `onnxruntime::ep::Ep` to inherit from These header files allow a compile time "switch" to the different set of types to minimize changes to existing code. Specifically, `pch.h` is required to be included as PCH to make sure the "override" to take place correctly. - add a new folder `onnxruntime/core/providers/webgpu/ep` for EP API implementation, specifically: - `api.cc`: implements `CreateEpFactories` and `ReleaseEpFactory` - `ep.cc` `ep.h`: implement class `onnxruntime::webgpu::ep::Ep` - `factory.cc` `factory.h`: implement class `onnxruntime::webgpu::ep::Factory` #### Dependencies and Prerequisites (unmerged changes are included as a part of current PR) - https://github.com/microsoft/onnxruntime/pull/26855 - https://github.com/microsoft/onnxruntime/pull/26803 - https://github.com/microsoft/onnxruntime/pull/26859 - https://github.com/microsoft/onnxruntime/pull/26879 - https://github.com/microsoft/onnxruntime/pull/26919 - https://github.com/microsoft/onnxruntime/pull/27569 - https://github.com/microsoft/onnxruntime/pull/27587 #### Missing Parts - Allow setting Global/Default EP options --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .github/workflows/windows_webgpu.yml | 80 +++++ cmake/onnxruntime_providers_webgpu.cmake | 8 +- cmake/onnxruntime_unittests.cmake | 13 + include/onnxruntime/ep/adapter/allocator.h | 37 ++- include/onnxruntime/ep/adapter/ep.h | 1 + .../onnxruntime/ep/adapter/op_kernel_info.h | 2 +- include/onnxruntime/ep/common.h | 21 ++ .../core/providers/cpu/tensor/upsamplebase.h | 13 +- .../core/providers/webgpu/compute_context.h | 4 + .../core/providers/webgpu/controlflow/if.cc | 16 +- .../core/providers/webgpu/controlflow/if.h | 15 +- .../core/providers/webgpu/data_transfer.cc | 50 ++-- .../core/providers/webgpu/data_transfer.h | 21 +- onnxruntime/core/providers/webgpu/ep/api.cc | 84 ++++++ onnxruntime/core/providers/webgpu/ep/ep.cc | 273 ++++++++++++++++++ onnxruntime/core/providers/webgpu/ep/ep.h | 77 +++++ .../core/providers/webgpu/ep/factory.cc | 206 +++++++++++++ .../core/providers/webgpu/ep/factory.h | 76 +++++ .../core/providers/webgpu/generator/range.cc | 54 ++-- .../core/providers/webgpu/tensor/cast.cc | 2 +- .../core/providers/webgpu/tensor/concat.cc | 2 +- .../core/providers/webgpu/tensor/expand.cc | 4 +- .../core/providers/webgpu/tensor/gather.cc | 2 +- .../core/providers/webgpu/tensor/pad.cc | 2 +- .../core/providers/webgpu/tensor/shape_op.cc | 64 +++- .../core/providers/webgpu/tensor/unsqueeze.cc | 4 +- .../core/providers/webgpu/tensor/upsample.cc | 2 +- .../webgpu/webgpu_execution_provider.cc | 88 +++++- .../webgpu/webgpu_execution_provider.h | 32 +- .../webgpu/webgpu_provider_factory.cc | 57 +++- .../webgpu/webgpu_provider_factory_creator.h | 2 +- .../test/contrib_ops/matmul_2bits_test.cc | 4 +- onnxruntime/test/unittest_util/base_tester.cc | 2 - .../unittest_util/test_dynamic_plugin_ep.cc | 10 +- .../unittest_util/test_dynamic_plugin_ep.h | 4 +- onnxruntime/test/util/default_providers.cc | 41 ++- 36 files changed, 1261 insertions(+), 112 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/ep/api.cc create mode 100644 onnxruntime/core/providers/webgpu/ep/ep.cc create mode 100644 onnxruntime/core/providers/webgpu/ep/ep.h create mode 100644 onnxruntime/core/providers/webgpu/ep/factory.cc create mode 100644 onnxruntime/core/providers/webgpu/ep/factory.h diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 872d3b182c310..e67eda41d2e0e 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -155,6 +155,86 @@ jobs: working-directory: ${{ github.workspace }}\csharp continue-on-error: true + webgpu_plugin_build_x64_RelWithDebInfo: + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=webgpu_plugin_build_x64_RelWithDebInfo-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] + timeout-minutes: 300 + env: + OnnxRuntimeBuildDirectory: ${{ github.workspace }} + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: "0" + DocUpdateNeeded: false + NVIDIA_TF32_OVERRIDE: "0" + ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: none + + - name: Setup Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: "3.12" + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + run: python -m pip install -r tools\ci_build\github\windows\python\requirements.txt + shell: cmd + working-directory: ${{ github.workspace }} + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: "20.x" + + - uses: actions/cache@v5 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: Build and Test + shell: pwsh + run: | + python.exe ${{ github.workspace }}\tools\ci_build\build.py ` + --config RelWithDebInfo ` + --build_dir ${{ github.workspace }} ` + --skip_submodule_sync ` + --parallel ` + --use_binskim_compliant_compile_flags ` + --cmake_generator "Visual Studio 17 2022" ` + --enable_onnx_tests ` + --use_webgpu shared_lib ` + --wgsl_template static ` + --use_vcpkg --use_vcpkg_ms_internal_asset_cache ` + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_DAWN_BACKEND_D3D12=1 onnxruntime_ENABLE_DAWN_BACKEND_VULKAN=1 ` + --disable_rtti ` + --enable_lto + + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + - name: Publish artifacts + uses: actions/upload-artifact@v4 + with: + name: webgpu-plugin-binaries + path: | + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_webgpu.dll + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_webgpu.pdb + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/dxcompiler.dll + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/dxil.dll + webgpu_external_dawn_build_x64_RelWithDebInfo: runs-on: [ "self-hosted", diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index be7f2613a6272..cd29e4dad0a17 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -56,7 +56,7 @@ endif() source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs}) - onnxruntime_add_shared_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu ${REPO_ROOT}/include/onnxruntime/core/session onnxruntime_common @@ -119,6 +119,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") message(FATAL_ERROR "WebGPU EP shared library build is not supported on Emscripten. Please use static library build.") endif() + + # Configure precompiled headers for shared library build + # PCH ensures ep/adapters.h is included first and improves compilation speed + target_precompile_headers(onnxruntime_providers_webgpu PRIVATE + "${REPO_ROOT}/include/onnxruntime/ep/adapters.h" + ) endif() set_target_properties(onnxruntime_providers_webgpu PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9ae3e79d86443..8137f8b3a2529 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1042,6 +1042,18 @@ function(onnxruntime_apply_test_target_workarounds target) endif() endfunction() +# Set environment variables for plugin EP tests when run via CTest. +function(onnxruntime_set_plugin_ep_test_environment target) + if(onnxruntime_USE_WEBGPU AND onnxruntime_USE_EP_API_ADAPTERS) + set(ORT_PLUGIN_EP_JSON_CONFIG "{\"ep_library_registration_name\": \"WebGPU_PluginEP\", \"ep_library_path\": \"$\", \"selected_ep_name\": \"WebGpuExecutionProvider\"}") + set_tests_properties(${target} PROPERTIES + ENVIRONMENT "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON=${ORT_PLUGIN_EP_JSON_CONFIG}" + ) + # TODO: add for other plugin EPs if needed + # elseif() + endif() +endfunction() + function(onnxruntime_apply_emscripten_test_link_settings target) if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(${target} PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) @@ -1250,6 +1262,7 @@ block() ) onnxruntime_apply_test_target_workarounds(onnxruntime_provider_test) + onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test) # Expose QNN SDK headers to unit tests via an interface target if(onnxruntime_USE_QNN) diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 2765069ebf336..4f107ae72c0e9 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -18,23 +18,50 @@ namespace adapter { /// class Allocator : public OrtAllocator { public: + /** + * Create from an existing AllocatorPtr. + */ explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl) - : OrtAllocator{}, memory_info_(memory_info), impl_(impl) { + : Allocator{memory_info} { + ORT_ENFORCE(impl != nullptr, "Allocator implementation cannot be null."); + impl_ = impl; + } + + using AllocatorFactory = AllocatorPtr (*)(const OrtMemoryInfo& memory_info); + + /** + * Create from an AllocatorFactory, which will be called lazily when the first allocation is made. + */ + explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorFactory get_allocator_impl) + : Allocator{memory_info} { + get_allocator_impl_ = get_allocator_impl; + } + + private: + explicit Allocator(const OrtMemoryInfo* memory_info) + : OrtAllocator{}, memory_info_(memory_info) { version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; } + AllocatorPtr GetImpl() { + if (!impl_) { + std::call_once(init_flag_, [this]() { + impl_ = get_allocator_impl_(*memory_info_); + }); + } + return impl_; + } - private: static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept { auto* allocator = static_cast(this_ptr); - return allocator->impl_->Alloc(size); + return allocator->GetImpl()->Alloc(size); } static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept { auto* allocator = static_cast(this_ptr); - allocator->impl_->Free(p); + allocator->GetImpl()->Free(p); } static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept { @@ -44,6 +71,8 @@ class Allocator : public OrtAllocator { const OrtMemoryInfo* memory_info_; AllocatorPtr impl_; + AllocatorFactory get_allocator_impl_; + std::once_flag init_flag_; }; } // namespace adapter diff --git a/include/onnxruntime/ep/adapter/ep.h b/include/onnxruntime/ep/adapter/ep.h index 34fc7682a8138..ca0a8c9599eda 100644 --- a/include/onnxruntime/ep/adapter/ep.h +++ b/include/onnxruntime/ep/adapter/ep.h @@ -27,6 +27,7 @@ class Ep : public OrtEp { profiler_{impl_->GetProfiler()}, temp_space_cpu_allocator_{temp_space_cpu_allocator}, temp_space_allocator_{temp_space_allocator} { + ort_version_supported = ORT_API_VERSION; } public: diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h index bd6172a668e33..644cb30788ec6 100644 --- a/include/onnxruntime/ep/adapter/op_kernel_info.h +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -19,7 +19,7 @@ #include "tensor_helper.h" namespace onnxruntime { -struct DataTransferManager; +class DataTransferManager; struct IExecutionProvider; } // namespace onnxruntime diff --git a/include/onnxruntime/ep/common.h b/include/onnxruntime/ep/common.h index 03cd571461755..0e779ba3d4081 100644 --- a/include/onnxruntime/ep/common.h +++ b/include/onnxruntime/ep/common.h @@ -41,3 +41,24 @@ OrtStatus* _status = (status_expr); \ Ort::Status _ignored{_status}; \ } while (false) + +// Helper macros to convert exceptions to OrtStatus* return values. +// Usage: +// EXCEPTION_TO_RETURNED_STATUS_BEGIN +// ... code that may throw ... +// EXCEPTION_TO_RETURNED_STATUS_END +#define EXCEPTION_TO_RETURNED_STATUS_BEGIN try { +#define EXCEPTION_TO_RETURNED_STATUS_END \ + } \ + catch (const Ort::Exception& ex) { \ + Ort::Status status(ex); \ + return status.release(); \ + } \ + catch (const std::exception& ex) { \ + Ort::Status status(ex.what(), ORT_EP_FAIL); \ + return status.release(); \ + } \ + catch (...) { \ + Ort::Status status("Unknown exception", ORT_EP_FAIL); \ + return status.release(); \ + } diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index 7dcf88133e967..ff5498c0b4644 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -265,8 +265,17 @@ class UpsampleBase { if (scales_input_idx_ > 0) { const Tensor* scale; bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale); - auto x_shape = node.InputDefs()[0]->Shape(); - int64_t rank = x_shape ? x_shape->dim_size() : -1; + int64_t rank = -1; + if constexpr (std::is_same_v) { + auto x_shape = node.InputDefs()[0]->Shape(); + if (x_shape != nullptr) { + rank = x_shape->dim_size(); + } + } else { + auto type_info = info.GetKernelInfo().GetInputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + rank = static_cast(tensor_info.GetDimensionsCount()); + } if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) { ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank)); scales_cached_ = true; diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 5277d64ad3611..38848e98509ba 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -107,7 +107,11 @@ class ComputeContextBase { // Get the logger. // inline const logging::Logger& Logger() const { +#if defined(ORT_USE_EP_API_ADAPTERS) + return ep_.GetEpLogger(); +#else return *ep_.GetLogger(); +#endif } // diff --git a/onnxruntime/core/providers/webgpu/controlflow/if.cc b/onnxruntime/core/providers/webgpu/controlflow/if.cc index 233d1f760383f..29b5e66d5075a 100644 --- a/onnxruntime/core/providers/webgpu/controlflow/if.cc +++ b/onnxruntime/core/providers/webgpu/controlflow/if.cc @@ -3,6 +3,10 @@ #include "core/providers/webgpu/controlflow/if.h" +#if defined(ORT_USE_EP_API_ADAPTERS) +#include "core/framework/error_code_helper.h" +#endif + using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -68,10 +72,20 @@ ONNX_OPERATOR_KERNEL_EX(If, .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), If); +#if !defined(ORT_USE_EP_API_ADAPTERS) Status If::Compute(OpKernelContext* ctx) const { // call the base CPU version. return onnxruntime::If::Compute(ctx); } +#else +Status If::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + return ToStatusAndRelease(ep::Api().ep.CreateIfKernel(info, impl)); +} + +Status If::Compute(OpKernelContext* /*ctx*/) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "If operator should be handled by ORT core."); +} +#endif } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/controlflow/if.h b/onnxruntime/core/providers/webgpu/controlflow/if.h index 0755c5d33d7a3..0aa30e939eb1a 100644 --- a/onnxruntime/core/providers/webgpu/controlflow/if.h +++ b/onnxruntime/core/providers/webgpu/controlflow/if.h @@ -10,6 +10,8 @@ namespace onnxruntime { namespace webgpu { +#if !defined(ORT_USE_EP_API_ADAPTERS) + // Use the CPU implementation for the logic class If final : public onnxruntime::If { public: @@ -18,5 +20,16 @@ class If final : public onnxruntime::If { Status Compute(OpKernelContext* ctx) const override; }; +#else + +class If final : public OpKernel { + public: + If(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + Status Compute(OpKernelContext* ctx) const override; +}; +#endif + } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index 6d66a7308f1de..5f109bf73e3c5 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -7,38 +7,48 @@ namespace onnxruntime { namespace webgpu { -bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || - (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || - (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); -} - -common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - size_t bytes = src.SizeInBytes(); +common::Status DataTransferImpl::CopyTensor(void const* src_data, + bool src_is_gpu, + void* dst_data, + bool dst_is_gpu, + size_t bytes) const { if (bytes > 0) { - void const* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; - - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { + if (dst_is_gpu) { + if (src_is_gpu) { // copy from GPU to GPU buffer_manager_.MemCpy(static_cast(const_cast(src_data)), - static_cast(dst_data), bytes); + static_cast(dst_data), + bytes); } else { // copy from CPU to GPU - buffer_manager_.Upload(const_cast(src_data), static_cast(dst_data), bytes); + buffer_manager_.Upload(const_cast(src_data), + static_cast(dst_data), + bytes); } - } else /* if (src_device.Type() == OrtDevice::GPU) */ { + } else { // copy from GPU to CPU - buffer_manager_.Download(static_cast(const_cast(src_data)), dst_data, bytes); + buffer_manager_.Download(static_cast(const_cast(src_data)), + dst_data, + bytes); } } return Status::OK(); } +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + return impl_.CopyTensor(src.DataRaw(), + src.Location().device.Type() == OrtDevice::GPU, + dst.MutableDataRaw(), + dst.Location().device.Type() == OrtDevice::GPU, + src.SizeInBytes()); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h index 0adf380149acf..e6ce92a7ca7a6 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.h +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -3,6 +3,7 @@ #pragma once +#include "core/common/status.h" #include "core/framework/data_transfer.h" #include "core/framework/execution_provider.h" @@ -11,9 +12,25 @@ namespace webgpu { class BufferManager; +// Low-level data transfer implementation that operates on raw pointers. +// Used by both DataTransfer (IDataTransfer subclass) and the C API data transfer wrapper. +class DataTransferImpl { + public: + DataTransferImpl(const BufferManager& buffer_manager) : buffer_manager_{buffer_manager} {}; + + common::Status CopyTensor(void const* src_data, + bool src_is_gpu, + void* dst_data, + bool dst_is_gpu, + size_t bytes) const; + + private: + const BufferManager& buffer_manager_; +}; + class DataTransfer : public IDataTransfer { public: - DataTransfer(const BufferManager& buffer_manager) : buffer_manager_{buffer_manager} {}; + DataTransfer(const BufferManager& buffer_manager) : impl_{buffer_manager} {}; ~DataTransfer() {}; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; @@ -21,7 +38,7 @@ class DataTransfer : public IDataTransfer { common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; private: - const BufferManager& buffer_manager_; + DataTransferImpl impl_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/ep/api.cc b/onnxruntime/core/providers/webgpu/ep/api.cc new file mode 100644 index 0000000000000..9eeb3d71df89f --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/api.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include + +#include "core/providers/webgpu/ep/factory.h" + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +namespace onnxruntime { +namespace webgpu { +void CleanupWebGpuContexts(); +void CleanupKernelRegistries(); +} // namespace webgpu +} // namespace onnxruntime + +namespace google { +namespace protobuf { +void ShutdownProtobufLibrary(); +} // namespace protobuf +} // namespace google + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + + // Manual init for the C++ API + onnxruntime::ep::ApiInit(ort_api_base); + + if (max_factories < 1) { + return onnxruntime::ep::Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + // Initialize the global default logger + ::onnxruntime::ep::adapter::LoggingManager::CreateDefaultLogger(default_logger); + + // Factory could use registration_name or define its own EP name. + std::unique_ptr factory = std::make_unique(); + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; + + EXCEPTION_TO_RETURNED_STATUS_END +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // STEP.1 - Release the factory + delete static_cast(factory); + + // STEP.2 - Clean up cached kernel registries + onnxruntime::webgpu::CleanupKernelRegistries(); + + // STEP.3 - Clean up WebGPU contexts + onnxruntime::webgpu::CleanupWebGpuContexts(); + + // STEP.4 - Destroy the global default logger wrapper + ::onnxruntime::ep::adapter::LoggingManager::DestroyDefaultLogger(); + + // STEP.5 - Shutdown protobuf library + google::protobuf::ShutdownProtobufLibrary(); + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +} // extern "C" diff --git a/onnxruntime/core/providers/webgpu/ep/ep.cc b/onnxruntime/core/providers/webgpu/ep/ep.cc new file mode 100644 index 0000000000000..6beb62b5cf074 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/ep.cc @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep.h" + +#include "factory.h" + +#include "core/framework/run_options.h" +#include "core/framework/kernel_registry.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" + +#include "ep/get_capability_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +using onnxruntime::ep::Api; + +// Constructor +Ep::Ep(std::unique_ptr impl, Factory& factory, const OrtLogger& logger, const Config& config) + : onnxruntime::ep::adapter::Ep{std::move(impl), config.cpu_allocator, config.device_allocator}, + factory_{factory}, + logger_{logger}, + config_{config} { + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + GetKernelRegistry = GetKernelRegistryImpl; + Compile = nullptr; // Per-kernel EP does not use Compile + ReleaseNodeComputeInfos = nullptr; + GetPreferredDataLayout = GetPreferredDataLayoutImpl; + ShouldConvertDataLayoutForOp = ShouldConvertDataLayoutForOpImpl; + SetDynamicOptions = nullptr; // Not implemented + OnRunStart = OnRunStartImpl; + OnRunEnd = OnRunEndImpl; + CreateAllocator = CreateAllocatorImpl; + CreateSyncStreamForDevice = nullptr; // Not stream aware + GetCompiledModelCompatibilityInfo = nullptr; // Not a compiled EP + IsConcurrentRunSupported = IsConcurrentRunSupportedImpl; +} + +// OrtEp interface implementations +const char* ORT_API_CALL Ep::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->factory_.GetName(&ep->factory_); +} + +OrtStatus* ORT_API_CALL Ep::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + + auto& ep = *static_cast(static_cast(this_ptr)->EpImpl()); + Ort::ConstGraph ort_graph{graph}; + + // Get all nodes in the graph + std::vector all_nodes = ort_graph.GetNodes(); + + if (all_nodes.empty()) { + return nullptr; // No nodes to process + } + + std::vector candidate_nodes; + std::vector tentative_candidate_nodes; + + // For each node, check if we have a registered kernel for it + for (const auto& node : all_nodes) { + std::string ep_name = node.GetEpName(); + + if (ep_name == kWebGpuExecutionProvider) { + candidate_nodes.push_back(node); + continue; + } + + // Reject nodes already assigned to a different (non-CPU) EP + if (!ep_name.empty() && ep_name != kCpuExecutionProvider) { + continue; + } + + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(Api().ep.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def)); + + if (kernel_def == nullptr) { + LOGS(ep.GetEpLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.GetOperatorType() << " node name: " << node.GetName(); + continue; + } + + auto cpu_node_names = ep.GetForceCpuNodeNames(); + if (std::find(cpu_node_names.begin(), + cpu_node_names.end(), + node.GetName()) != cpu_node_names.end()) { + LOGS(ep.GetEpLogger(), INFO) << "Force CPU execution for node: " << node.GetName(); + continue; + } + + // + // The following code checks if the node is really supported by webgpu EP + // + +#define FALLBACK_TO_CPU_IF_EXIST_INPUT(idx) \ + if (inputs.size() > idx && inputs[idx] != nullptr) { \ + continue; \ + } + +#define FALLBACK_TO_CPU_IF_EXIST_OUTPUT(idx) \ + if (outputs.size() > idx && outputs[idx] != nullptr) { \ + continue; \ + } + + // Check for Attention + if (node.GetOperatorType() == "Attention" && node.GetDomain() == kMSDomain) { + const auto& inputs = node.GetInputs(); + const auto& outputs = node.GetOutputs(); + + // Current implementation does not support mask_index(input[3]), past(input[4]) and past_seq_len(input[6]) + FALLBACK_TO_CPU_IF_EXIST_INPUT(3); + FALLBACK_TO_CPU_IF_EXIST_INPUT(4); + FALLBACK_TO_CPU_IF_EXIST_INPUT(6); + + // Current implementation does not support present(output[1]) + FALLBACK_TO_CPU_IF_EXIST_OUTPUT(1); + + // If attribute past_present_share_buffer is set, fallback to CPU + bool has_past_present_share_buffer = false; + for (const auto& attr : node.GetAttributes()) { + if (attr.GetName() == "past_present_share_buffer") { + int64_t val = 0; + RETURN_IF_ERROR(attr.GetValue(val)); + if (val != 0) { + has_past_present_share_buffer = true; + } + break; + } + } + if (has_past_present_share_buffer) { + continue; + } + } + + candidate_nodes.push_back(node); + tentative_candidate_nodes.push_back(node); + } + + std::unordered_set cpu_preferred_nodes; + RETURN_IF_ERROR(onnxruntime::ep::GetCpuPreferredNodes(*ort_graph, + *graph_support_info, + static_cast(this_ptr)->GetOrtLogger(), + tentative_candidate_nodes, + cpu_preferred_nodes)); + + for (const auto& node : candidate_nodes) { + if (cpu_preferred_nodes.count(node) == 0) { + RETURN_IF_ERROR(Api().ep.EpGraphSupportInfo_AddSingleNode(graph_support_info, node)); + } + } + + return nullptr; + + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + + *kernel_registry = nullptr; + + // For the WebGPU EP, delegate to the CreateKernelRegistry function + // which properly constructs a registry using only public APIs + auto* ep = static_cast(this_ptr); + + auto& webgpu_ep = *static_cast(ep->EpImpl()); + + *kernel_registry = *webgpu_ep.GetKernelRegistryImpl(); + return nullptr; + + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::GetPreferredDataLayoutImpl(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // Delegate to the underlying WebGPU EP's GetPreferredLayout() + // DataLayout enum values map 1:1 to OrtEpDataLayout (NCHW=0, NHWC=1) + auto* ep = static_cast(this_ptr); + *preferred_data_layout = static_cast(ep->EpImpl()->GetPreferredLayout()); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // DataLayout enum values map 1:1 to OrtEpDataLayout (NCHW=0, NHWC=1) + auto* ep = static_cast(this_ptr); + auto result = ep->EpImpl()->ShouldConvertDataLayoutForOp(domain, op_type, + static_cast(target_data_layout)); + if (result.has_value()) { + *should_convert = result.value() ? 1 : 0; + } else { + *should_convert = -1; + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::OnRunStartImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + onnxruntime::RunOptions options{}; + // currently only option "gpu_graph_id" is used + auto graph_annotation_str = Api().ort.GetRunConfigEntry(run_options, kOrtRunOptionsConfigCudaGraphAnnotation); + if (graph_annotation_str != nullptr) { + auto status = options.config_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, graph_annotation_str); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + } + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->OnRunStart(options); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::OnRunEndImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* /*run_options*/, + _In_ bool sync_stream) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->OnRunEnd(sync_stream, {}); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::IsConcurrentRunSupportedImpl(_In_ OrtEp* /*this_ptr*/, _Out_ bool* is_concurrent_run_supported) noexcept { + *is_concurrent_run_supported = false; + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto* ep = static_cast(this_ptr); + Ort::ConstMemoryInfo ort_memory_info{memory_info}; + if (ort_memory_info.GetAllocatorType() == OrtReadOnlyAllocator) { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, ep->config_.initializer_allocator); + } else { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, ep->config_.device_allocator); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/ep.h b/onnxruntime/core/providers/webgpu/ep/ep.h new file mode 100644 index 0000000000000..815623025f8a4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/ep.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/providers/webgpu/webgpu_execution_provider.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +class Factory; + +/// +/// A bridge class between the EP API and the WebGPU EP implementation. +/// +class Ep : public onnxruntime::ep::adapter::Ep { + public: + struct Config { + AllocatorPtr cpu_allocator; + AllocatorPtr device_allocator; + AllocatorPtr initializer_allocator; + }; + + Ep(std::unique_ptr impl, Factory& factory, const OrtLogger& logger, const Config& config); + + inline const OrtLogger& GetOrtLogger() const noexcept { + return logger_; + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept; + + static OrtStatus* ORT_API_CALL GetPreferredDataLayoutImpl(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout) noexcept; + + static OrtStatus* ORT_API_CALL ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert) noexcept; + + static OrtStatus* ORT_API_CALL OnRunStartImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options) noexcept; + + static OrtStatus* ORT_API_CALL OnRunEndImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream) noexcept; + + static OrtStatus* ORT_API_CALL IsConcurrentRunSupportedImpl(_In_ OrtEp* this_ptr, + _Out_ bool* is_concurrent_run_supported) noexcept; + + Factory& factory_; + const OrtLogger& logger_; + Config config_{}; +}; + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/factory.cc b/onnxruntime/core/providers/webgpu/ep/factory.cc new file mode 100644 index 0000000000000..99dd0c68f6954 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/factory.cc @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "factory.h" +#include "ep.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" + +#include "core/framework/execution_provider.h" +#include "core/framework/config_options.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/allocator.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +using onnxruntime::ep::Api; + +// Constructor +Factory::Factory() : OrtEpFactory{}, + default_memory_info_{WEBGPU_BUFFER, OrtMemoryInfoDeviceType_GPU, + 0, // vendor id + 0, // device id + OrtDeviceMemoryType_DEFAULT, + 0, // alignment + OrtDeviceAllocator}, + readonly_memory_info_{WEBGPU_BUFFER, OrtMemoryInfoDeviceType_GPU, + 0, // vendor id + 0, // device id + OrtDeviceMemoryType_DEFAULT, + 0, // alignment + OrtReadOnlyAllocator} { + ort_version_supported = ORT_API_VERSION; + + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; +} + +// Static C API implementations + +const char* ORT_API_CALL Factory::GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return kWebGpuExecutionProvider; +} + +const char* ORT_API_CALL Factory::GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return "Microsoft"; +} + +uint32_t ORT_API_CALL Factory::GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return 0; +} + +const char* ORT_API_CALL Factory::GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return "0.1.0"; +} + +OrtStatus* ORT_API_CALL Factory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto factory = static_cast(this_ptr); + + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (Api().ort.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + OrtEpDevice* ep_device = nullptr; + ORT_API_RETURN_IF_ERROR(Api().ep.CreateEpDevice(this_ptr, + &device, nullptr, nullptr, + &ep_device)); + ORT_API_RETURN_IF_ERROR(Api().ep.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_)); + ORT_API_RETURN_IF_ERROR(Api().ep.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_)); + ep_devices[num_ep_devices++] = ep_device; + } + } + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Factory::CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + if (num_devices == 0) { + return Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, "No hardware devices provided to create WebGPU EP."); + } + + OrtKeyValuePairs* session_config_entries = nullptr; + ORT_API_RETURN_IF_ERROR(Api().ort.GetSessionOptionsConfigEntries(session_options, &session_config_entries)); + Ort::KeyValuePairs session_config_entries_holder(session_config_entries); // allow automatic release + + auto config_options = ConfigOptions{}; + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + Api().ort.GetKeyValuePairs(session_config_entries, &keys, &values, &num_entries); + for (size_t i = 0; i < num_entries; ++i) { + auto status = config_options.AddConfigEntry(keys[i], values[i]); + if (!status.IsOK()) { + return Api().ort.CreateStatus((OrtErrorCode)status.Code(), status.ErrorMessage().c_str()); + } + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(config_options); + auto webgpu_ep = webgpu_ep_factory->CreateProvider(*session_options, *logger); + static_cast(webgpu_ep.get())->SetEpLogger(logger); + auto factory = static_cast(this_ptr); + const int context_id = webgpu_ep->GetDeviceId(); + Ep::Config webgpu_ep_config{ + CPUAllocator::DefaultInstance(), // CPU allocator + std::make_shared(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator + std::make_shared(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator + }; + *ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +void ORT_API_CALL Factory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + delete static_cast(ep); +} + +OrtStatus* ORT_API_CALL Factory::CreateAllocatorImpl( + OrtEpFactory* /*this_ptr*/, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Ort::ConstMemoryInfo ort_memory_info{memory_info}; + + if (ort_memory_info.GetAllocatorType() != OrtDeviceAllocator || + ort_memory_info.GetDeviceId() != 0 || + ort_memory_info.GetAllocatorName() != WEBGPU_BUFFER) { + return Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, + "Unsupported memory info for shared allocator."); + } + + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, + [](const OrtMemoryInfo&) -> AllocatorPtr { + return std::make_shared(WebGpuContextFactory::DefaultContext().BufferManager(), false); + }); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +void ORT_API_CALL Factory::ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* allocator) noexcept { + onnxruntime::ep::adapter::Allocator* ptr = static_cast(allocator); + delete ptr; +} + +OrtStatus* ORT_API_CALL Factory::CreateDataTransferImpl( + OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + *data_transfer = OrtWebGpuCreateDataTransfer(); // TODO(fs-eire): pass context id if needed + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +bool ORT_API_CALL Factory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; // Default: not stream aware +} + +OrtStatus* ORT_API_CALL Factory::CreateSyncStreamForDeviceImpl( + OrtEpFactory* /*this_ptr*/, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + *stream = nullptr; + return Api().ort.CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + EXCEPTION_TO_RETURNED_STATUS_END +} + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/factory.h b/onnxruntime/core/providers/webgpu/ep/factory.h new file mode 100644 index 0000000000000..f23b3871ebc60 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/factory.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "ep.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +/// +/// A bridge class between the EP API and the WebGPU EP Factory implementation. +/// +class Factory : public OrtEpFactory { + private: + // Static C API implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + Ort::MemoryInfo default_memory_info_; + Ort::MemoryInfo readonly_memory_info_; // used for initializers + + public: + Factory(); + ~Factory() = default; +}; + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc index 3fa062f327ba2..a84660a020fed 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.cc +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -79,36 +79,40 @@ template class Range; template class Range; template class Range; -void RegisterRangeKernels(KernelRegistry& kernel_registry, bool enable_int64) { - // Helper lambda to create kernel - auto create_range_kernel_info = [](auto type_tag) { - using T = decltype(type_tag); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { - out = std::make_unique>(info); - return Status::OK(); - }; - - return KernelCreateInfo( - KernelDefBuilder() - .SetName("Range") - .SetDomain(kOnnxDomain) - .SinceVersion(11) - .Provider(kWebGpuExecutionProvider) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .InputMemoryType(OrtMemTypeCPU, 0) - .InputMemoryType(OrtMemTypeCPU, 1) - .InputMemoryType(OrtMemTypeCPU, 2) - .Build(), - kernel_create_fn); - }; +namespace { +template +Status CreateRangeKernel(FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { + out = std::make_unique>(info); + return Status::OK(); +} + +template +KernelCreateInfo CreateRangeKernelInfo() { + return KernelCreateInfo( + KernelDefBuilder() + .SetName("Range") + .SetDomain(kOnnxDomain) + .SinceVersion(11) + .Provider(kWebGpuExecutionProvider) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 0) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .Build(), + CreateRangeKernel); +} + +} // namespace + +void RegisterRangeKernels(KernelRegistry& kernel_registry, bool enable_int64) { // Always register float and int32_t - ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(float{}))); - ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(int32_t{}))); + ORT_THROW_IF_ERROR(kernel_registry.Register(CreateRangeKernelInfo())); + ORT_THROW_IF_ERROR(kernel_registry.Register(CreateRangeKernelInfo())); // Register int64_t only if int64 support is enabled if (enable_int64) { - ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(int64_t{}))); + ORT_THROW_IF_ERROR(kernel_registry.Register(CreateRangeKernelInfo())); } } diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 6bb0f688bfdb7..2695c2800d37a 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -90,7 +90,7 @@ template KernelCreateInfo CreateCastKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 75453b991a0cd..c6178d44dba75 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -98,7 +98,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } Prepare prepare; - ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), input_tensors, prepare)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(&context.KernelContext(), input_tensors, prepare)); if (prepare.output_num_elements == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 0dacd589cbba8..0d39b1ec9d35e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -108,7 +108,7 @@ template KernelCreateInfo CreateExpandVersionedKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; @@ -129,7 +129,7 @@ template KernelCreateInfo CreateExpandKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index b3e5c7b4e8310..970c2d6bed7a3 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -60,7 +60,7 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { Status Gather::ComputeInternal(ComputeContext& context) const { Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(&context.KernelContext(), p)); uint32_t data_size = onnxruntime::narrow(p.output_tensor->Shape().Size()); if (data_size == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc index 0e77ec46bbddb..7a576c4b53ecf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/pad.cc +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -49,7 +49,7 @@ Status Pad::ComputeInternal(ComputeContext& context) const { const auto pads_data = pads_tensor->DataAsSpan(); // Compute Pads by applying axes if specified otherwise copy the supplied pads. - PadBase::ComputePads(context.KernelContext(), data_rank, pads_data, pads); + PadBase::ComputePadsImpl(context.KernelContext(), data_rank, pads_data, pads); // Separate out any negative pads into the slices array PadBase::SeparateNegativeToSlices(pads, slices); diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc index b211d48dab1c9..09194aa9f4dbb 100644 --- a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -3,11 +3,73 @@ #include "core/providers/webgpu/webgpu_kernel.h" #include "core/providers/webgpu/webgpu_supported_types.h" -#include "core/providers/cpu/tensor/shape_op.h" namespace onnxruntime { namespace webgpu { +#ifndef SHARED_PROVIDER +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/framework/op_kernel.h" +#endif + +#include +#include + +class Shape final : public OpKernel { + public: + Shape(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } + } + + // Takes a tensor as input and outputs an 1D int64 tensor + // containing the shape of the input tensor. + Status Compute(OpKernelContext* context) const override { + const auto* input = context->Input(0); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); + + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->MutableData(), onnxruntime::narrow(true_start), onnxruntime::narrow(slice_length)); + } + } + + return Status::OK(); + } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); +}; + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Shape, kOnnxDomain, diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc index 104fcf1812af8..3337448564bed 100644 --- a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -12,7 +12,7 @@ template KernelCreateInfo CreateUnsqueezeVersionedKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; @@ -47,7 +47,7 @@ template KernelCreateInfo CreateUnsqueezeKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git a/onnxruntime/core/providers/webgpu/tensor/upsample.cc b/onnxruntime/core/providers/webgpu/tensor/upsample.cc index fb406883ba4ba..8f51ed45004bf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/webgpu/tensor/upsample.cc @@ -90,7 +90,7 @@ Status Upsample::ComputeInternal(ComputeContext& context) const { InlinedVector scales_array(input_dims.size()); // opset < 10 - if (OpKernel::Node().InputDefs().size() == 1) { + if (OpKernel::Node().SinceVersion() < 10) { scales_array = scales_; // Compute output shape from scales attributes and input dims ComputeOutputShape(scales_array, input_dims, output_dims); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index c6255e6f352d9..b4d751ce3a2c0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -3,6 +3,7 @@ #include "core/providers/webgpu/webgpu_execution_provider.h" +#include #include #include #include @@ -445,8 +446,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterElements); -std::unique_ptr RegisterKernels(bool enable_graph_capture = false, bool enable_int64 = false) { - auto kernel_registry = std::make_unique(); +std::unique_ptr RegisterKernels(bool enable_graph_capture, bool enable_int64) { + auto kernel_registry = std::make_unique(); static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing @@ -837,6 +838,72 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals return kernel_registry; } +#if defined(ORT_USE_EP_API_ADAPTERS) + +namespace { +std::mutex g_kernel_registry_mutex; +std::shared_ptr g_kernel_registry; +std::shared_ptr g_graph_capture_kernel_registry; +std::shared_ptr g_int64_kernel_registry; +} // namespace + +void CleanupKernelRegistries() { + std::lock_guard lock(g_kernel_registry_mutex); + g_kernel_registry.reset(); + g_graph_capture_kernel_registry.reset(); + g_int64_kernel_registry.reset(); +} +#endif + +std::shared_ptr GetKernelRegistry(bool enable_graph_capture, bool enable_int64) { + // kernel registry variables are defined differently based on build configuration + // + // - When building as a static library, use static local variable. This is because + // we don't have a reliable way to explicitly destroy the kernel registry after + // use. + // + // - When building as a shared library, use global variables. The cleanup will be performed + // when `ReleaseEpFactory` is called. + // + // Graph capture mode needs a separate kernel registry because contrib kernel registration + // differs based on enable_graph_capture, and enable_int64 is always true when + // enable_graph_capture is true. + if (enable_graph_capture) { +#if !defined(ORT_USE_EP_API_ADAPTERS) + static std::shared_ptr registry = RegisterKernels(true, true); + return registry; +#else + std::lock_guard lock(g_kernel_registry_mutex); + if (g_graph_capture_kernel_registry == nullptr) { + g_graph_capture_kernel_registry = RegisterKernels(true, true); + } + return g_graph_capture_kernel_registry; +#endif + } else if (enable_int64) { +#if defined(ORT_USE_EP_API_ADAPTERS) + std::lock_guard lock(g_kernel_registry_mutex); + if (g_int64_kernel_registry == nullptr) { + g_int64_kernel_registry = RegisterKernels(false, true); + } + return g_int64_kernel_registry; +#else + static std::shared_ptr registry = RegisterKernels(false, true); + return registry; +#endif + } else { +#if defined(ORT_USE_EP_API_ADAPTERS) + std::lock_guard lock(g_kernel_registry_mutex); + if (g_kernel_registry == nullptr) { + g_kernel_registry = RegisterKernels(false, false); + } + return g_kernel_registry; +#else + static std::shared_ptr registry = RegisterKernels(false, false); + return registry; +#endif + } +} + } // namespace webgpu using namespace webgpu; @@ -850,6 +917,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, preferred_data_layout_{config.data_layout}, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, enable_graph_capture_{config.enable_graph_capture}, + // enable_int64_ is always true when enable_graph_capture_ is true enable_int64_{config.enable_graph_capture || config.enable_int64}, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { @@ -882,6 +950,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { }; } +#if !defined(ORT_USE_EP_API_ADAPTERS) std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, @@ -973,20 +1042,7 @@ std::vector> WebGpuExecutionProvider::GetCapa return result; } -std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { - // Cache registries based on enable_graph_capture_ and enable_int64_ flags - // Note: enable_int64_ is always true when enable_graph_capture_ is true - if (enable_graph_capture_) { - static std::shared_ptr registry = webgpu::RegisterKernels(true, true); - return registry; - } else if (enable_int64_) { - static std::shared_ptr registry = webgpu::RegisterKernels(false, true); - return registry; - } else { - static std::shared_ptr registry = webgpu::RegisterKernels(false, false); - return registry; - } -} +#endif // !defined(ORT_USE_EP_API_ADAPTERS) std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { return std::make_unique(BufferManager()); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index b5a6b5f167faf..b46d3f3cb45d2 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -4,6 +4,11 @@ #pragma once +#include +#include +#include +#include + #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" #include "core/graph/constants.h" @@ -28,6 +33,9 @@ class GpuBufferAllocator; // Forward declare CapturedCommandInfo which is now defined in webgpu_context.h struct CapturedCommandInfo; + +// The actual implementation of kernel registration. +std::shared_ptr GetKernelRegistry(bool enable_graph_capture, bool enable_int64); } // namespace webgpu struct WebGpuExecutionProviderConfig { @@ -44,13 +52,21 @@ class WebGpuExecutionProvider : public IExecutionProvider { WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& config); ~WebGpuExecutionProvider() override; + inline auto GetKernelRegistryImpl() const { + return webgpu::GetKernelRegistry(enable_graph_capture_, enable_int64_); + } + +#if !defined(ORT_USE_EP_API_ADAPTERS) std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; - std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override { + return GetKernelRegistryImpl(); + } +#endif std::unique_ptr GetDataTransfer() const override; #if defined(__wasm__) std::unique_ptr GetExternalDataLoader() const override; @@ -83,8 +99,18 @@ class WebGpuExecutionProvider : public IExecutionProvider { Status ReplayGraph(int graph_annotation_id) override; webgpu::BufferManager& BufferManager() const; AllocatorPtr PrepackAllocator() const { return prepack_allocator_; } + std::span GetForceCpuNodeNames() const { return force_cpu_node_names_; } uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; } +#if defined(ORT_USE_EP_API_ADAPTERS) + inline onnxruntime::ep::adapter::Logger& GetEpLogger() const { + return *ep_logger_; + } + inline void SetEpLogger(const OrtLogger* logger) { + ep_logger_ = std::make_unique(logger); + } +#endif + private: bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); @@ -114,6 +140,10 @@ class WebGpuExecutionProvider : public IExecutionProvider { // Allocator for prepacked weights (uses buffers without mapping) AllocatorPtr prepack_allocator_; + +#if defined(ORT_USE_EP_API_ADAPTERS) + std::unique_ptr ep_logger_; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index fc2496f0c7b68..16899370e47f1 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -284,11 +284,11 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( // WebGPU DataTransfer implementation wrapper for the C API with lazy initialization struct WebGpuDataTransferImpl : OrtDataTransferImpl { - WebGpuDataTransferImpl(const OrtApi& ort_api_in) + WebGpuDataTransferImpl(const OrtApi& ort_api_in, int context_id) : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()}, data_transfer_{nullptr}, - context_id_{0}, // Always use context 0 for Environment's data transfer + context_id_{context_id}, init_mutex_{} { ort_version_supported = ORT_API_VERSION; CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback @@ -327,9 +327,9 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { // If both are GPU, they must have the same device ID if (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) { - uint64_t src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); - uint64_t dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); - if (src_device_id != dst_device_id) { + int src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); + int dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); + if (src_device_id != impl.context_id_ || dst_device_id != impl.context_id_) { return false; // Cannot copy between different devices } } @@ -362,19 +362,40 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { auto& context = WebGpuContextFactory::DefaultContext(); - // Create the DataTransfer instance - // Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle + // Create the DataTransferImpl instance + // Note: The DataTransferImpl holds a const reference to BufferManager. The BufferManager's lifecycle // is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists // for the lifetime of the application, ensuring the reference remains valid. - impl.data_transfer_ = std::make_unique(context.BufferManager()); + impl.data_transfer_ = std::make_unique(context.BufferManager()); } } // Now perform the actual tensor copy for (size_t idx = 0; idx < num_tensors; ++idx) { - const OrtValue* src_tensor = src_tensors[idx]; - OrtValue* dst_tensor = dst_tensors[idx]; - auto status = impl.data_transfer_->CopyTensor(src_tensor->Get(), *dst_tensor->GetMutable()); +#if defined(ORT_USE_EP_API_ADAPTERS) + Ort::ConstValue src_value{src_tensors[idx]}; + const void* src_data = src_value.GetTensorRawData(); + size_t size = src_value.GetTensorSizeInBytes(); + bool src_is_gpu = src_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU; + + Ort::UnownedValue dst_value{dst_tensors[idx]}; + void* dst_data = dst_value.GetTensorMutableRawData(); + bool dst_is_gpu = dst_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU; +#else + const Tensor& src_tensor = src_tensors[idx]->Get(); + const void* src_data = src_tensor.DataRaw(); + size_t size = src_tensor.SizeInBytes(); + bool src_is_gpu = src_tensor.Location().device.Type() == OrtDevice::GPU; + + Tensor& dst_tensor = *dst_tensors[idx]->GetMutable(); + void* dst_data = dst_tensor.MutableDataRaw(); + bool dst_is_gpu = dst_tensor.Location().device.Type() == OrtDevice::GPU; +#endif + auto status = impl.data_transfer_->CopyTensor(src_data, + src_is_gpu, + dst_data, + dst_is_gpu, + size); if (!status.IsOK()) { return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str()); } @@ -398,19 +419,23 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { const OrtApi& ort_api; const OrtEpApi& ep_api; - std::unique_ptr data_transfer_; // Lazy-initialized - int context_id_; // Track which context we're using - std::mutex init_mutex_; // Protects lazy initialization + std::unique_ptr data_transfer_; // Lazy-initialized + int context_id_; // Track which context we're using + std::mutex init_mutex_; // Protects lazy initialization }; -OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() { +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id /* = 0 */) { +#if defined(ORT_USE_EP_API_ADAPTERS) + return new WebGpuDataTransferImpl(onnxruntime::ep::Api().ort, context_id); +#else // Validate API version is supported const OrtApi* api = OrtApis::GetApi(ORT_API_VERSION); if (!api) { // API version not supported - return nullptr to indicate failure return nullptr; } - return new WebGpuDataTransferImpl(*api); + return new WebGpuDataTransferImpl(*api, context_id); +#endif } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 021e33ef25309..876a2e11d791a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -22,6 +22,6 @@ struct WebGpuProviderFactoryCreator { // C API to create data transfer for WebGPU EP with lazy initialization // Context will be determined from tensors during the first CopyTensors call // Caller takes ownership of the returned OrtDataTransferImpl* -OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(); +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id = 0); } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 00ea8947b9dd7..06ff495327ecd 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -451,7 +451,7 @@ TEST(MatMul2Bits, Float32_2b_Accuracy4) { TestMatMul2BitsTyped(); } -#ifdef USE_WEBGPU +#if defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) namespace { @@ -594,7 +594,7 @@ TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_DP4A) { RunTest2Bits(opts); } -#endif // USE_WEBGPU +#endif // defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 18ead92ce3f18..1f744df14cfb8 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -749,10 +749,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultXnnpackExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); -#if !defined(USE_WEBGPU) || !defined(ORT_USE_EP_API_ADAPTERS) else if (provider_type == onnxruntime::kWebGpuExecutionProvider) execution_provider = DefaultWebGpuExecutionProvider(); -#endif else if (provider_type == dynamic_plugin_ep_name) { execution_provider = dynamic_plugin_ep_infra::MakeEp(); } diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc index fd2cf2f712628..1f82e1f893eab 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -11,6 +11,7 @@ #include "nlohmann/json.hpp" #include "core/common/common.h" +#include "core/framework/config_options.h" #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" #include "core/session/ort_env.h" @@ -167,7 +168,7 @@ void Shutdown() { g_plugin_ep_infrastructure_state.reset(); } -std::unique_ptr MakeEp(const logging::Logger* logger) { +std::unique_ptr MakeEp(const logging::Logger* logger, const ConfigOptions* ep_options) { if (!IsInitialized()) { return nullptr; } @@ -182,6 +183,13 @@ std::unique_ptr MakeEp(const logging::Logger* logger) { StrMapToKeyValueCstrVectors(state.config.default_ep_options, default_ep_option_key_cstrs, default_ep_option_value_cstrs); + if (ep_options != nullptr) { + for (const auto& [key, value] : ep_options->configurations) { + default_ep_option_key_cstrs.push_back(key.c_str()); + default_ep_option_value_cstrs.push_back(value.c_str()); + } + } + OrtSessionOptions ort_session_options{}; ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, default_ep_option_key_cstrs, diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h index 680045be9330c..0962df8e35308 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h @@ -17,6 +17,7 @@ namespace onnxruntime { struct IExecutionProviderFactory; class IExecutionProvider; +struct ConfigOptions; namespace logging { class Logger; @@ -74,7 +75,8 @@ bool IsInitialized(); void Shutdown(); // Returns a dynamic plugin EP `IExecutionProvider` instance, or `nullptr` if uninitialized. -std::unique_ptr MakeEp(const logging::Logger* logger = nullptr); +// `ep_options` provides additional EP-specific option overrides (key-value pairs) on top of the defaults. +std::unique_ptr MakeEp(const logging::Logger* logger = nullptr, const ConfigOptions* ep_options = nullptr); // Gets the dynamic plugin EP name, or `std::nullopt` if uninitialized. std::optional GetEpName(); diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 6dc38f84c79d5..7e6bc6ae06020 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -14,8 +14,13 @@ #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" #endif +#if defined(USE_WEBGPU) +#include "core/graph/constants.h" +#include "core/session/abi_session_options_impl.h" +#endif #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/providers.h" +#include "test/unittest_util/test_dynamic_plugin_ep.h" namespace onnxruntime { @@ -273,19 +278,37 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { } std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) { -#if defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) +#if defined(USE_WEBGPU) ConfigOptions config_options{}; + + // Helper to strip the EP prefix from config entry keys when building as a plugin EP. + // The full key is like "ep.webgpuexecutionprovider.storageBufferCacheMode", and the + // config entry expects just "storageBufferCacheMode" in the EP API build. + // Returns a pointer into the original string, so the result is valid as long as the input is. + auto strip_ep_prefix = [](const char* full_key) -> const char* { +#if defined(ORT_USE_EP_API_ADAPTERS) + std::string_view key{full_key}; + std::string_view prefix = OrtSessionOptions::GetProviderOptionPrefix(kWebGpuExecutionProvider); + ORT_ENFORCE(key.length() >= prefix.length() && key.substr(0, prefix.length()) == prefix, + "Config key \"", key, "\" does not start with expected prefix \"", prefix, "\""); + return full_key + prefix.length(); +#else + return full_key; +#endif + }; + // Disable storage buffer cache - ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + ORT_ENFORCE(config_options.AddConfigEntry(strip_ep_prefix(webgpu::options::kStorageBufferCacheMode), webgpu::options::kBufferCacheMode_Disabled) .IsOK()); if (!is_nhwc) { // Enable NCHW support - ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreferredLayout, + ORT_ENFORCE(config_options.AddConfigEntry(strip_ep_prefix(webgpu::options::kPreferredLayout), webgpu::options::kPreferredLayout_NCHW) .IsOK()); } - return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); + + return WebGpuExecutionProviderWithOptions(config_options); #else ORT_UNUSED_PARAMETER(is_nhwc); return nullptr; @@ -293,8 +316,16 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) } std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) { -#if defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) +#if defined(USE_WEBGPU) +#if defined(ORT_USE_EP_API_ADAPTERS) + auto ep_name = dynamic_plugin_ep_infra::GetEpName(); + ORT_ENFORCE(ep_name == kWebGpuExecutionProvider, + "Dynamic plugin EP is not the WebGPU EP. Expected \"", kWebGpuExecutionProvider, + "\", got \"", ep_name.value_or(""), "\""); + return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options); +#else return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); +#endif #else ORT_UNUSED_PARAMETER(config_options); return nullptr; From 142eccab71312c368034bcb610d0105d19b28486 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 24 Mar 2026 15:13:21 -0400 Subject: [PATCH 15/17] [WebGPU] Einsum fixes for 5D tensors (#27779) ### Description Fixes WebGPU Einsum op by replacing the manual `uniforms.inputN_shape[idx]` string with `GetElementAt(...)` which correctly handles uniform shape access for all tensor ranks. I also added a bunch of tests for this... ### Motivation and Context Closes https://github.com/microsoft/onnxruntime/issues/27762 --- .../core/providers/webgpu/math/einsum.cc | 7 +- .../test/providers/cpu/math/einsum_test.cc | 100 ++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/einsum.cc b/onnxruntime/core/providers/webgpu/math/einsum.cc index bce173b1c62de..e17c0281c738f 100644 --- a/onnxruntime/core/providers/webgpu/math/einsum.cc +++ b/onnxruntime/core/providers/webgpu/math/einsum.cc @@ -325,9 +325,12 @@ Status EinsumProgram::GenerateShaderCode(ShaderHelper& shader) const { // Generate a WGSL loop header for reduction over this dimension // Format like: for(var j: u32 = 0; j < uniforms.input0_shape[1]; j++) {, given equation // "ij,jk->ik". + std::string shape_access = GetElementAt( + "uniforms.input" + std::to_string(lhs_term_index) + "_shape", + input_index, + static_cast(inputs[lhs_term_index].get().Rank())); reduce_ops_loop_headers.push_back("for(var " + symbol + ": u32 = 0; " + symbol + " < " + - "uniforms.input" + std::to_string(lhs_term_index) + - "_shape[" + std::to_string(input_index) + "]; " + + shape_access + "; " + symbol + "++) {"); // Add corresponding loop closing brace diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 2bf62b6944735..f732103842146 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -2171,5 +2171,105 @@ TEST_P(EinsumTransposeMatMulThreeInputsTest, EinsumTransposeMatMulThreeInputsTes INSTANTIATE_TEST_SUITE_P(EinsumTransposeMatMulThreeInputsTests, EinsumTransposeMatMulThreeInputsTest, testing::ValuesIn(case1)); +// Theme: High-rank contractions (WebGPU shader generation regression tests) + +// 5D contraction (Mamba-style chunked SSM state computation) +TEST(Einsum, ExplicitEinsumAs5DContraction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "bcknd,bckns->bcnds"); + test.AddInput("x", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddInput("y", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddOutput("o", {1, 1, 2, 2, 2}, + {26.f, 32.f, 32.f, 40.f, 58.f, 68.f, 68.f, 80.f}); + test.Run(); +} + +// 5D x 5D contraction (contract middle dims, keep outer + inner) +TEST(Einsum, ExplicitEinsumAs5DContraction_abcde_abcdf) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcde,abcdf->abef"); + test.AddInput("x", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddInput("y", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddOutput("o", {1, 1, 2, 2}, + {84.f, 100.f, 100.f, 120.f}); + test.Run(); +} + +// 5D x 5D contraction (contract 3 trailing dims) +TEST(Einsum, ExplicitEinsumAs5DContraction_abcde_afcde) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcde,afcde->abf"); + test.AddInput("x", {1, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); + test.AddInput("y", {1, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); + test.AddOutput("o", {1, 2, 2}, + {204.f, 492.f, 492.f, 1292.f}); + test.Run(); +} + +// 5D reduction (reduce 2 of 5 axes) +TEST(Einsum, ExplicitEinsumAs5DReduction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcde->ace"); + test.AddInput("x", {2, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}); + test.AddOutput("o", {2, 2, 2}, + {24.f, 28.f, 40.f, 44.f, 88.f, 92.f, 104.f, 108.f}); + test.Run(); +} + +// 6D x 6D contraction +TEST(Einsum, ExplicitEinsumAs6DContraction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcdef,abcdeg->abcfg"); + test.AddInput("x", {1, 1, 1, 1, 2, 2}, + {1.f, 2.f, 3.f, 4.f}); + test.AddInput("y", {1, 1, 1, 1, 2, 2}, + {1.f, 2.f, 3.f, 4.f}); + test.AddOutput("o", {1, 1, 1, 2, 2}, + {10.f, 14.f, 14.f, 20.f}); + test.Run(); +} + +// 6D reduction (reduce 3 of 6 axes) +TEST(Einsum, ExplicitEinsumAs6DReduction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcdef->adf"); + test.AddInput("x", {2, 2, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, + 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f}); + test.AddOutput("o", {2, 2, 2}, + {112.f, 120.f, 144.f, 152.f, 368.f, 376.f, 400.f, 408.f}); + test.Run(); +} + +// 3-input bilinear form (x^T A y reduced to scalar) +TEST(Einsum, ExplicitEinsumAsBilinearFormToScalar) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "i,ij,j->"); + test.AddInput("x", {3}, {1.f, 2.f, 3.f}); + test.AddInput("y", {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + test.AddInput("z", {4}, {1.f, 2.f, 3.f, 4.f}); + test.AddOutput("o", {}, {500.f}); + test.Run(); +} + } // namespace test } // namespace onnxruntime From 2f668789d9cac8fd708867f820ed0d606de1e361 Mon Sep 17 00:00:00 2001 From: Nico Martin Date: Tue, 24 Mar 2026 20:41:58 +0100 Subject: [PATCH 16/17] Fix WebGPU device destroyed on session release, breaking session recreation (#27634) ## Description We had a weird behavior in Transformers.js V4. After calling `InferenceSession.release()` on a WebGPU session, attempting to create a new WebGPU session fails with: ``` WebGPU device lost (2): Device was destroyed. ``` In Transformers.js we encourage the use of the `create -> release -> create` pattern, because we expect the application to run for some time and might use multiple models. So it makes sense to unload models after the job is done. It seems like this was introduced in [e03631ee528](https://github.com/microsoft/onnxruntime/commit/e03631ee528), which added the `preserveDevice` option with a default value of `false`. When the last session is released and `preserveDevice=false`, the C++ side destroys the WebGPU device, but the JavaScript reference in `env.webgpu.device` is never cleared, leaving a stale reference to a destroyed device. ## Changes **Clear stale device reference when lost** (`backend-webgpu.ts`) 1. Made device property `configurable: true` to allow deletion 2. Added cleanup logic in `dispose()` to detect device loss via `device.lost` promise 3. When device is lost (destroyed, driver crash, etc.), delete the stale `env.webgpu.device` reference This allows subsequent session creation to acquire a fresh device instead of attempting to reuse a lost one. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e486e4b0e043d..eeff82484ef5f 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -278,7 +278,7 @@ export class WebGpuBackend { value: this.device, writable: false, enumerable: true, - configurable: false, + configurable: true, // Allow deletion when device is destroyed }); Object.defineProperty(this.env.webgpu, 'adapter', { value: adapter, @@ -296,6 +296,14 @@ export class WebGpuBackend { this.querySet.destroy(); } this.gpuDataManager.dispose(); + + // Clear the device reference when it's lost to allow new sessions to create a fresh device + // This handles the case where preserve_device=false (default) causes the C++ side to destroy the device + if (this.device && this.env?.webgpu) { + void this.device.lost.then(() => { + delete (this.env.webgpu as unknown as Record).device; + }); + } } getCommandEncoder(): GPUCommandEncoder { From 1b982ddcb2293575e283b4a28d1d791e8842dc54 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 24 Mar 2026 13:56:59 -0700 Subject: [PATCH 17/17] [CPU] Handle ONNX domain Gelu and HardSigmoid activations in the NCHWc transformer suite (#27821) ### Description As title ### Motivation and Context Tiny continuation to https://github.com/microsoft/onnxruntime/pull/27691 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../core/optimizer/nchwc_transformer.cc | 13 ++- .../test/optimizer/nchwc_optimizer_test.cc | 85 +++++++++++++++++-- 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 4e03450077718..b9366ff0abae8 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -881,12 +881,21 @@ void NchwcTransformerImpl::TransformActivation(Node& node) { const bool can_fuse_activation = (node.OpType() == "Relu") || (node.OpType() == "Sigmoid") || - (node.OpType() == "Tanh"); + (node.OpType() == "Tanh") || + (node.OpType() == "HardSigmoid"); if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) && can_fuse_activation && (nchwc_input->starting_original_uses_ == 1) && (graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) { nchwc_node.AddAttribute("activation", node.OpType()); + if (node.OpType() == "HardSigmoid") { + const auto* alpha_attr = graph_utils::GetNodeAttribute(node, "alpha"); + const auto* beta_attr = graph_utils::GetNodeAttribute(node, "beta"); + InlinedVector activation_params{ + alpha_attr == nullptr ? 0.2f : alpha_attr->f(), + beta_attr == nullptr ? 0.5f : beta_attr->f()}; + nchwc_node.AddAttribute("activation_params", activation_params); + } FuseNchwcArgument(node, *nchwc_input); removed_nodes_.push_front(node.Index()); } else { @@ -1265,8 +1274,10 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6, 22}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {20}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) { TransformActivation(node); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 6078660bf0d6e..cd210f7bc70ba 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -1407,7 +1407,9 @@ TEST(NchwcOptimizerTests, UpsampleLinear) { } TEST(NchwcOptimizerTests, Activation) { - auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { + auto test_case = [&](const std::string& activation_op_type, + const std::string& domain = kOnnxDomain, + int opset_version = 13) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 48, 11, 15}); auto* conv1_output_arg = helper.MakeIntermediate(); @@ -1431,23 +1433,93 @@ TEST(NchwcOptimizerTests, Activation) { EXPECT_EQ(op_to_count["Add"], 1); }; - NchwcOptimizerTester(build_test_case, check_nchwc_graph); + NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version); }; // Verify that the optimizer doesn't add reorders for these activations in - // this pattern. Relu/Sigmoid/Tanh are generally fusable with a + // this pattern. Relu/Sigmoid/Tanh/HardSigmoid are generally fusable with a // preceding convolution, but not here because the Conv output is consumed // both by the activation node and directly by the Add node. Gelu/QuickGelu // are also expected to remain as separate nodes. test_case("Relu"); test_case("Sigmoid"); test_case("Tanh"); + test_case("HardSigmoid"); + test_case("Gelu", kOnnxDomain, 20); test_case("Gelu", kMSDomain); test_case("QuickGelu", kMSDomain); } -TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { - auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { +TEST(NchwcOptimizerTests, ActivationSingleConsumerConvFusion) { + constexpr float kHardSigmoidAlpha = 0.125f; + constexpr float kHardSigmoidBeta = 0.625f; + + auto test_case = [&](const std::string& activation_op_type) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 48, 11, 15}); + auto* conv1_output_arg = helper.MakeIntermediate(); + auto* activation_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); + auto& activation_node = helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}); + if (activation_op_type == "HardSigmoid") { + activation_node.AddAttribute("alpha", kHardSigmoidAlpha); + activation_node.AddAttribute("beta", kHardSigmoidBeta); + } + helper.AddConvNode(activation_output_arg, output_arg, {16, 32, 1, 1}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto& graph = session.GetGraph(); + auto op_to_count = CountOpsInGraph(graph); + + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count[activation_op_type], 0); + + size_t fused_conv_count = 0; + for (const auto& node : graph.Nodes()) { + if (node.OpType() != "Conv" || node.Domain() != kMSNchwcDomain) { + continue; + } + + const auto& attributes = node.GetAttributes(); + auto activation_it = attributes.find("activation"); + if (activation_it == attributes.end()) { + continue; + } + + fused_conv_count++; + EXPECT_EQ(activation_it->second.s(), activation_op_type); + + auto activation_params_it = attributes.find("activation_params"); + if (activation_op_type == "HardSigmoid") { + ASSERT_NE(activation_params_it, attributes.end()); + ASSERT_EQ(activation_params_it->second.floats_size(), 2); + EXPECT_FLOAT_EQ(activation_params_it->second.floats(0), kHardSigmoidAlpha); + EXPECT_FLOAT_EQ(activation_params_it->second.floats(1), kHardSigmoidBeta); + } else { + EXPECT_EQ(activation_params_it, attributes.end()); + } + } + + EXPECT_EQ(fused_conv_count, 1U); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + for (const auto& activation_op_type : {"Relu", "Sigmoid", "Tanh", "HardSigmoid"}) { + test_case(activation_op_type); + } +} + +TEST(NchwcOptimizerTests, ActivationSingleConsumerConvNoFusion) { + auto test_case = [&](const std::string& activation_op_type, + const std::string& domain = kOnnxDomain, + int opset_version = 13) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 48, 11, 15}); auto* conv1_output_arg = helper.MakeIntermediate(); @@ -1477,12 +1549,13 @@ TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { } }; - NchwcOptimizerTester(build_test_case, check_nchwc_graph); + NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version); }; // Gelu/QuickGelu must remain separate even with a single-consumer Conv input, // because the NCHWc Conv activation fuse guard only allows a fixed subset of // activations. + test_case("Gelu", kOnnxDomain, 20); test_case("Gelu", kMSDomain); test_case("QuickGelu", kMSDomain); }