From 163f6149967fc93bde668132bfd06ff43002c1cb Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Mon, 23 Mar 2026 16:46:10 +0800
Subject: [PATCH 1/5] [CPU/CUDA EP] Add DeformConv op support (#27393)
### Description
This change adds support for the Deformable Convolution 2D operator
(DeformConv) to ONNX Runtime. It implements the operator
schema and registration, provides kernel implementations (CPU and
GPU/CUDA where available), implements shape inference, and adds unit and
integration tests to validate correctness and numerical parity with
reference implementations. The changes include performance-oriented
optimizations and necessary changes to build/test scripts.
### Motivation and Context
Deformable convolutions are widely used in vision models that require
spatial sampling flexibility (e.g., Deformable ConvNets, some
detection/segmentation models). Native support in ONNX Runtime enables
these models to run efficiently without custom operators or external
runtimes, broadening the set of compatible models and improving
performance and portability.
### See also
- https://onnx.ai/onnx/operators/onnx__DeformConv.html
-
https://docs.pytorch.org/vision/main/generated/torchvision.ops.deform_conv2d.html
- https://arxiv.org/abs/1811.11168
- https://arxiv.org/abs/1703.06211
-
https://github.com/pytorch/vision/blob/0f6d91d9fe514e6de2f5519114cbeb389d498b2d/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
-
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
-
https://github.com/pytorch/vision/blob/0f6d91d9fe514e6de2f5519114cbeb389d498b2d/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
-
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
- #22060
- #15572
- #20810
- #16903
- https://github.com/onnx/onnx/issues/5451
- https://github.com/ZhengPeng7/BiRefNet/pull/167
- https://github.com/pytorch/pytorch/issues/68910
- https://github.com/pytorch/vision/issues/2066
---
docs/OperatorKernels.md | 4 +
.../providers/cpu/cpu_execution_provider.cc | 8 +
.../core/providers/cpu/nn/deform_conv.cc | 321 ++++++
.../core/providers/cpu/nn/deform_conv.h | 24 +
.../providers/cpu/nn/deform_conv_attributes.h | 198 ++++
.../providers/cuda/cuda_execution_provider.cc | 14 +
.../core/providers/cuda/nn/deform_conv.cc | 331 ++++++
.../core/providers/cuda/nn/deform_conv.h | 27 +
.../providers/cuda/nn/deform_conv_impl.cu | 512 ++++++++++
.../core/providers/cuda/nn/deform_conv_impl.h | 63 ++
.../cpu/nn/deform_conv_expected_gen.py | 180 ++++
.../providers/cpu/nn/deform_conv_op_test.cc | 948 ++++++++++++++++++
.../test/testdata/deform_conv_test.onnx | Bin 0 -> 353 bytes
.../test/testdata/deform_conv_test_data.inc | 10 +
.../test/testdata/deform_conv_test_data.npz | Bin 0 -> 2092 bytes
.../test/testdata/nn/deform_conv_test_gen.py | 194 ++++
16 files changed, 2834 insertions(+)
create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv.cc
create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv.h
create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv.cc
create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv.h
create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu
create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv_impl.h
create mode 100644 onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py
create mode 100644 onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc
create mode 100644 onnxruntime/test/testdata/deform_conv_test.onnx
create mode 100644 onnxruntime/test/testdata/deform_conv_test_data.inc
create mode 100644 onnxruntime/test/testdata/deform_conv_test_data.npz
create mode 100644 onnxruntime/test/testdata/nn/deform_conv_test_gen.py
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 39c9145a40912..25829d08206cf 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -103,6 +103,8 @@ Do not modify directly.*
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**
or
*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
|||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
+|DeformConv|*in* X:**T**<br> *in* W:**T**<br> *in* offset:**T**<br> *in* B:**T**<br> *in* mask:**T**<br> *out* Y:**T**|22+|**T** = tensor(double), tensor(float)|
+|||[19, 21]|**T** = tensor(double), tensor(float)|
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
@@ -697,6 +699,8 @@ Do not modify directly.*
|Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)|
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)|
+|DeformConv|*in* X:**T**<br> *in* W:**T**<br> *in* offset:**T**<br> *in* B:**T**<br> *in* mask:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 74b8f8e468097..9f19a20a2e680 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1220,6 +1220,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Shape);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv);
// Opset 20
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, 20, ConstantOfShape);
@@ -1316,6 +1318,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Ac
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Atanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ConvTranspose);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, DeformConv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, DeformConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Det);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_float, Dropout);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_double, Dropout);
@@ -3277,6 +3281,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Resize)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv)>,
// Opset 20
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, DeformConv)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, DeformConv)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
new file mode 100644
index 0000000000000..f128b0e0182ad
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
@@ -0,0 +1,321 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// CPU implementation of DeformConv (deformable convolution 2D).
+
+#include "deform_conv.h"
+
+#include <cmath>
+
+#include "core/common/common.h"
+#include "core/common/narrow.h"
+#include "core/common/safeint.h"
+#include "core/util/math.h"
+#include "core/util/math_cpuonly.h"
+
+namespace onnxruntime {
+
+namespace {
+// Bilinear interpolation at (h, w). Out-of-bounds samples return 0 (ONNX spec).
+// Indices use int (not int64_t) to reduce register pressure and improve occupancy in the hot path.
+// Limitation: height and width must not exceed INT_MAX, or casting floor(h)/floor(w) to int may overflow.
+// Acceptable in practice: deformable convolution spatial dimensions are typically well below INT_MAX.
+template <typename T>
+T BilinearInterpolate(const T* in, int height, int width, T h, T w) {
+  // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case).
+  // h <= -1 (resp. w <= -1) means even the "high" neighbor h_low + 1 is < 0, so all 4 taps are 0.
+  if (h <= static_cast<T>(-1) || h >= height || w <= static_cast<T>(-1) || w >= width) {
+    return static_cast<T>(0);
+  }
+
+  // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw.
+  const T h_floor = std::floor(h);
+  const T w_floor = std::floor(w);
+  const int h_low = static_cast<int>(h_floor);
+  const int w_low = static_cast<int>(w_floor);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  // Fractional weights: (lh, lw) toward the high corner, (hh, hw) toward the low corner.
+  const T lh = h - h_floor;
+  const T lw = w - w_floor;
+  const T hh = static_cast<T>(1) - lh;
+  const T hw = static_cast<T>(1) - lw;
+
+  // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)).
+  // Most sampling points in deformable conv fall here; avoids 4 per-corner branches.
+  // [Optimization 3]: Use unsigned comparison so a negative h_low/w_low wraps to a huge value
+  // and fails the single compare (folds the >= 0 and < dim-1 checks into one branch each).
+  if (static_cast<unsigned>(h_low) < static_cast<unsigned>(height - 1) &&
+      static_cast<unsigned>(w_low) < static_cast<unsigned>(width - 1)) {
+    const int base_low = h_low * width;
+    const int base_high = h_high * width;
+    return hh * hw * in[base_low + w_low] +
+           hh * lw * in[base_low + w_high] +
+           lh * hw * in[base_high + w_low] +
+           lh * lw * in[base_high + w_high];
+  }
+
+  // Slow path: near boundary (one or more of the 4 corners may be out of bounds).
+  // Out-of-bounds corners contribute 0, matching the ONNX zero-padding sampling rule.
+  const int base_low = h_low * width;
+  const int base_high = h_high * width;
+  const T v1 = (h_low >= 0 && w_low >= 0) ? in[base_low + w_low] : static_cast<T>(0);
+  const T v2 = (h_low >= 0 && w_high < width) ? in[base_low + w_high] : static_cast<T>(0);
+  const T v3 = (h_high < height && w_low >= 0) ? in[base_high + w_low] : static_cast<T>(0);
+  const T v4 = (h_high < height && w_high < width) ? in[base_high + w_high] : static_cast<T>(0);
+  return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
+}
+
+// Deformable Im2Col for a SINGLE image.
+// Converts the input image into a matrix suitable for GEMM by sampling with learned offsets.
+// Output 'data_col' shape: [C_in * kH * kW, H_out * W_out]
+// When UseMask=false, pass nullptr for data_mask; compiler eliminates dead code for mask.
+template <bool UseMask, typename T>
+void DeformableIm2col(
+    const T* data_im,                      // Input image [C, H, W]
+    const T* data_offset,                  // Offset [offset_groups * 2 * kH * kW, H_out, W_out]
+    const T* data_mask,                    // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false)
+    int height, int width,                 // Input spatial dimensions (validated H*W <= INT_MAX)
+    int64_t kernel_h, int64_t kernel_w,    // Kernel dimensions
+    int64_t pad_h, int64_t pad_w,          // Padding (begin) for H and W
+    int64_t stride_h, int64_t stride_w,    // Stride for H and W
+    int64_t dilation_h, int64_t dilation_w,  // Dilation for H and W
+    int64_t channels,                      // Input channels
+    int64_t offset_groups,                 // Number of offset groups (channels shared per group)
+    int64_t height_col, int64_t width_col,  // Output spatial dimensions (H_out, W_out)
+    T* data_col,                           // Output buffer for im2col result
+    concurrency::ThreadPool* thread_pool) {
+  const int64_t channel_per_offset_group = channels / offset_groups;
+  const int64_t kernel_size = kernel_h * kernel_w;
+  const int64_t output_size = height_col * width_col;
+
+  // Parallelize over (channel, kernel_position) so each task processes one full row of data_col.
+  // This yields channels*kernel_size tasks, better CPU utilization and cache-friendly sequential writes.
+  concurrency::ThreadPool::TryParallelFor(
+      thread_pool,
+      static_cast<std::ptrdiff_t>(channels * kernel_size),
+      static_cast<double>(output_size) * 10.0,
+      [&](ptrdiff_t begin, ptrdiff_t end) {
+        for (ptrdiff_t idx = begin; idx < end; ++idx) {
+          // Decompose idx into (c_im, i, j): which channel and kernel position.
+          const int64_t j = static_cast<int64_t>(idx) % kernel_w;
+          const int64_t i = (static_cast<int64_t>(idx) / kernel_w) % kernel_h;
+          const int64_t c_im = static_cast<int64_t>(idx) / kernel_size;
+          const int64_t offset_grp = c_im / channel_per_offset_group;
+
+          // Output row: one (channel, kernel_pos) across all spatial locations.
+          T* col_ptr = data_col + static_cast<int64_t>(idx) * output_size;
+          const T* im_ptr = data_im + c_im * static_cast<int64_t>(height) * width;
+
+          // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened.
+          // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w.
+          // Precompute pointers to avoid offset_base * output_size multiplication in inner loop.
+          const int64_t offset_base =
+              offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j);
+          const T* ptr_offset_h = data_offset + offset_base * output_size;
+          const T* ptr_offset_w = data_offset + (offset_base + 1) * output_size;
+
+          // Base terms for h_im, w_im: invariant in inner loop (i, j fixed).
+          const T base_h = -pad_h + static_cast<T>(i) * dilation_h;
+          const T base_w = -pad_w + static_cast<T>(j) * dilation_w;
+
+          // Mask pointer; only used when UseMask=true (compiler removes when false).
+          [[maybe_unused]] const T* ptr_mask = nullptr;
+          if constexpr (UseMask) {
+            ptr_mask = data_mask + (offset_grp * kernel_size + i * kernel_w + j) * output_size;
+          }
+
+          // Loop over output spatial positions.
+          for (int64_t h_col = 0; h_col < height_col; ++h_col) {
+            for (int64_t w_col = 0; w_col < width_col; ++w_col) {
+              const int64_t spatial_idx = h_col * width_col + w_col;
+
+              const T offset_h = ptr_offset_h[spatial_idx];
+              const T offset_w = ptr_offset_w[spatial_idx];
+
+              // Deformed sampling coordinates (fractional, for bilinear interpolation).
+              const T h_im = h_col * stride_h + base_h + offset_h;
+              const T w_im = w_col * stride_w + base_w + offset_w;
+
+              // Sample input at deformed location; returns 0 if out of bounds.
+              T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im);
+
+              // Modulate by mask when UseMask=true; compiled away when false.
+              // Design choice: we always interpolate then multiply, rather than skip when mask==0.
+              // Rationale: (1) Skipping adds a branch; unpredictable mask values cause misprediction
+              // penalties (~15-20 cycles). (2) Straight-line code vectorizes better; conditional
+              // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN
+              // usage (moderate mask density), the unconditional path usually wins.
+              if constexpr (UseMask) {
+                val *= ptr_mask[spatial_idx];
+              }
+
+              col_ptr[spatial_idx] = val;
+            }
+          }
+        }
+      });
+}
+
+} // namespace
+
+template <typename T>
+Status DeformConv<T>::Compute(OpKernelContext* context) const {
+  const auto* X = context->Input<Tensor>(0);
+  const auto* W = context->Input<Tensor>(1);
+  const auto* offset = context->Input<Tensor>(2);
+  const auto* B = context->Input<Tensor>(3);     // optional
+  const auto* mask = context->Input<Tensor>(4);  // optional
+
+  DeformConvParams params;
+  ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(
+      attrs_,
+      X->Shape(),
+      W->Shape(),
+      offset->Shape(),
+      B ? &B->Shape() : nullptr,
+      mask ? &mask->Shape() : nullptr,
+      params));
+
+  const int64_t N = params.N;
+  const int64_t C = params.C;
+  const int64_t H = params.H;
+  const int64_t W_in = params.W_in;
+  const int64_t M = params.M;
+  const int64_t kH = params.kH;
+  const int64_t kW = params.kW;
+  const int64_t pad_h = params.pad_h;
+  const int64_t pad_w = params.pad_w;
+  const int64_t stride_h = params.stride_h;
+  const int64_t stride_w = params.stride_w;
+  const int64_t dilation_h = params.dilation_h;
+  const int64_t dilation_w = params.dilation_w;
+  const int64_t group = params.group;
+  const int64_t offset_group = params.offset_group;
+  const int64_t out_h = params.out_h;
+  const int64_t out_w = params.out_w;
+  const bool use_mask = params.use_mask;
+
+  // Allocate output tensor [N, M, out_h, out_w].
+  const TensorShape Y_shape({N, M, out_h, out_w});
+  Tensor* Y = context->Output(0, Y_shape);
+  if (Y->Shape().Size() == 0) {
+    return Status::OK();
+  }
+
+  // Precompute common sizes for the im2col + GEMM pipeline.
+  const int64_t kernel_size = kH * kW;
+  const int64_t output_image_size = out_h * out_w;
+  const int64_t input_image_size = H * W_in;
+  const int64_t kernel_dim = C / group * kernel_size;  // K dimension for GEMM: C/group * kH * kW
+
+  // Col buffer: shape [C*kH*kW, out_h*out_w]. Allocate per-image (process one image at a time)
+  // to reduce peak memory when N is large; im2col is implemented per-image anyway.
+  const int64_t col_buffer_size = (C * kernel_size) * output_image_size;
+
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+  auto col_buffer = IAllocator::MakeUniquePtr<T>(alloc, SafeInt<size_t>(col_buffer_size));
+
+  const T* Xdata = X->Data<T>();
+  const T* Wdata = W->Data<T>();
+  const T* offset_data = offset->Data<T>();
+  const T* mask_data = use_mask ? mask->Data<T>() : nullptr;
+  T* Ydata = Y->MutableData<T>();
+  const T* Bdata = (B != nullptr) ? B->Data<T>() : nullptr;
+
+  concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+  // Process each image in the batch.
+  for (int64_t n = 0; n < N; ++n) {
+    // Step 1: Deformable Im2Col for image n.
+    // Gather deformed samples into col buffer for GEMM.
+    const T* X_curr = Xdata + n * (C * input_image_size);
+    const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size);
+    const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr;
+    T* col_buffer_ptr = col_buffer.get();
+
+    // Dispatch to template instantiation: UseMask=true or false eliminates branch in hot loop.
+    // Note: pad_h, pad_w are begin-side paddings for coordinate mapping; pad_h_end/pad_w_end
+    // affect only output size (already baked into out_h, out_w), not im2col sampling.
+    if (use_mask) {
+      DeformableIm2col<true>(
+          X_curr, offset_curr, mask_curr,
+          static_cast<int>(H), static_cast<int>(W_in), kH, kW,
+          pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+          C, offset_group, out_h, out_w,
+          col_buffer_ptr, thread_pool);
+    } else {
+      DeformableIm2col<false>(
+          X_curr, offset_curr, nullptr,
+          static_cast<int>(H), static_cast<int>(W_in), kH, kW,
+          pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+          C, offset_group, out_h, out_w,
+          col_buffer_ptr, thread_pool);
+    }
+
+    // Step 2: GEMM for each group. Y = W * Col (per group).
+    for (int64_t g = 0; g < group; ++g) {
+      // Weight for group g: shape [M/group, C/group, kH, kW], row-major.
+      const T* weight_g = Wdata + g * (M / group) * kernel_dim;
+
+      // Col rows for group g: layout [C*kH*kW, out_h*out_w], group g spans rows [g*kernel_dim, (g+1)*kernel_dim).
+      const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size;
+
+      // Output slice for group g: [n, g*M/group:(g+1)*M/group, out_h, out_w].
+      T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size;
+
+      // GEMM: Y = W * Col. W [M/group, kernel_dim], Col [kernel_dim, output_image_size].
+      math::Gemm<T>(
+          CblasNoTrans,
+          CblasNoTrans,
+          narrow<ptrdiff_t>(M / group),            // M
+          narrow<ptrdiff_t>(output_image_size),    // N
+          narrow<ptrdiff_t>(kernel_dim),           // K
+          static_cast<T>(1),                       // alpha
+          weight_g,                                // A
+          col_g,                                   // B
+          static_cast<T>(0),                       // beta
+          Y_g,                                     // C
+          thread_pool,
+          nullptr);  // mlas_backend_kernel_selector_config
+    }
+  }
+
+  // Step 3: Add bias if provided (broadcast over spatial dimensions).
+  if (Bdata != nullptr) {
+    int64_t total_work = N * M;
+    concurrency::ThreadPool::TryParallelFor(
+        thread_pool, static_cast<std::ptrdiff_t>(total_work), static_cast<double>(output_image_size),
+        [&](ptrdiff_t first, ptrdiff_t last) {
+          for (ptrdiff_t idx = first; idx < last; ++idx) {
+            int64_t n = idx / M;
+            int64_t m = idx % M;
+            T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size;
+            // Eigen vectorized add: Y_ptr += Bdata[m] over all spatial positions.
+            EigenVectorArrayMap<T>(Y_ptr, narrow<ptrdiff_t>(output_image_size)) += Bdata[m];
+          }
+        });
+  }
+
+  return Status::OK();
+}
+
+// Explicit template instantiation for float and double
+template class DeformConv<float>;
+template class DeformConv<double>;
+
+// Registers both the versioned (opset 19-21) and current (opset 22+) kernels for type T.
+#define REGISTER_DEFORMCONV_KERNEL_TYPED(T)                         \
+  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(                         \
+      DeformConv, 19, 21, T,                                        \
+      KernelDefBuilder()                                            \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),   \
+      DeformConv<T>);                                               \
+  ONNX_CPU_OPERATOR_TYPED_KERNEL(                                   \
+      DeformConv, 22, T,                                            \
+      KernelDefBuilder()                                            \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),   \
+      DeformConv<T>)
+
+REGISTER_DEFORMCONV_KERNEL_TYPED(float)
+REGISTER_DEFORMCONV_KERNEL_TYPED(double)
+
+#undef REGISTER_DEFORMCONV_KERNEL_TYPED
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h
new file mode 100644
index 0000000000000..c8d7763e58bcb
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/op_node_proto_helper.h"
+#include "deform_conv_attributes.h"
+
+namespace onnxruntime {
+
+// Kernel wrapper for ONNX DeformConv (deformable convolution 2D) on CPU.
+// T is the element type (float/double, per the registrations in deform_conv.cc).
+template <typename T>
+class DeformConv : public OpKernel {
+ public:
+  // Attributes are parsed once at construction; per-call shape validation happens in Compute.
+  explicit DeformConv(const OpKernelInfo& info) : OpKernel(info), attrs_(info) {}
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  DeformConvAttributes attrs_;  // kernel_shape/strides/pads/dilations/group/offset_group
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
new file mode 100644
index 0000000000000..8bc891bb4f377
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
@@ -0,0 +1,198 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <climits>
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/tensor_shape.h"
+
+namespace onnxruntime {
+
+// Shared attributes for ONNX DeformConv (opset 19+).
+// See https://onnx.ai/onnx/operators/onnx__DeformConv.html
+// Used by both CPU and CUDA implementations (CUDA includes from here).
+struct DeformConvAttributes {
+  explicit DeformConvAttributes(const OpKernelInfo& info) {
+    // Optional attributes.
+    // If not present, they will be empty/default, and handled in Compute/ComputeInternal.
+    (void)info.GetAttrs("kernel_shape", kernel_shape);
+    (void)info.GetAttrs("strides", strides);
+    (void)info.GetAttrs("pads", pads);
+    (void)info.GetAttrs("dilations", dilations);
+    group = info.GetAttrOrDefault<int64_t>("group", 1);
+    offset_group = info.GetAttrOrDefault<int64_t>("offset_group", 1);
+  }
+
+  TensorShapeVector kernel_shape;  // empty => inferred from weight spatial dims
+  TensorShapeVector strides;       // empty => [1, 1]
+  TensorShapeVector pads;          // empty => [0, 0, 0, 0]
+  TensorShapeVector dilations;     // empty => [1, 1]
+  int64_t group{1};
+  int64_t offset_group{1};
+};
+
+// Parsed and validated parameters from DeformConv inputs.
+// Filled by DeformConvValidateAndParse; used by both CPU and CUDA implementations.
+// Field names align with ONNX DeformConv spec: https://onnx.ai/onnx/operators/onnx__DeformConv.html
+struct DeformConvParams {
+  // Input X shape (N, C, H, W)
+  int64_t N{0};     // Batch size (validated non-negative; N == 0 produces an empty output)
+  int64_t C{0};     // Number of input channels (validated positive)
+  int64_t H{0};     // Input height (validated non-negative)
+  int64_t W_in{0};  // Input width (W_in to avoid collision with weight W)
+
+  // Weight W shape (oC, C/group, kH, kW)
+  int64_t M{0};   // Number of output channels (oC; validated positive)
+  int64_t kH{0};  // Kernel height (from kernel_shape attr or weight dim 2)
+  int64_t kW{0};  // Kernel width (from kernel_shape attr or weight dim 3)
+
+  // Pads [x1_begin, x2_begin, x1_end, x2_end] for spatial axes H, W (validated non-negative)
+  int64_t pad_h{0};
+  int64_t pad_w{0};
+  int64_t pad_h_end{0};
+  int64_t pad_w_end{0};
+
+  // Strides and dilations along each spatial axis (default 1; validated positive)
+  int64_t stride_h{1};
+  int64_t stride_w{1};
+  int64_t dilation_h{1};
+  int64_t dilation_w{1};
+
+  // Attributes: C and oC must be divisible by group; C must be divisible by offset_group
+  int64_t group{1};         // Number of groups for input/output channels
+  int64_t offset_group{1};  // Number of groups of offset
+
+  // Output Y shape (N, oC, oH, oW) -- computed from the fields above, not read from inputs
+  int64_t out_h{0};  // Output height (oH)
+  int64_t out_w{0};  // Output width (oW)
+
+  bool use_mask{false};  // Whether optional mask input is provided
+};
+
+// Validates inputs and parses attributes into params.
+// Returns Status::OK() on success; on failure, params may be partially filled.
+inline Status DeformConvValidateAndParse(
+    const DeformConvAttributes& attrs,
+    const TensorShape& X_shape,
+    const TensorShape& W_shape,
+    const TensorShape& offset_shape,
+    const TensorShape* B_shape,      // nullptr when bias input absent
+    const TensorShape* mask_shape,   // nullptr when mask input absent
+    DeformConvParams& params) {
+  ORT_RETURN_IF_NOT(X_shape.NumDimensions() == 4, "Input X must be 4D (N, C, H, W).");
+  ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D.");
+  ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D.");
+
+  // Parse input shapes
+  params.N = X_shape[0];
+  params.C = X_shape[1];
+  params.H = X_shape[2];
+  params.W_in = X_shape[3];
+  params.M = W_shape[0];
+  ORT_RETURN_IF_NOT(params.N >= 0, "Batch size N must be non-negative.");
+  ORT_RETURN_IF_NOT(params.C > 0, "Input channels C must be positive.");
+  ORT_RETURN_IF_NOT(params.M > 0, "Output channels M (oC) must be positive.");
+  ORT_RETURN_IF_NOT(W_shape[1] > 0, "Weight W must have positive in-channels (W_shape[1] = C/group).");
+
+  // Handle kernel shape inference. If kernel_shape is provided, it must match weight spatial dims
+  // to avoid GEMM using wrong K and potential out-of-bounds reads from the weight buffer.
+  const int64_t W_kH = W_shape[2];
+  const int64_t W_kW = W_shape[3];
+  if (!attrs.kernel_shape.empty()) {
+    ORT_RETURN_IF_NOT(attrs.kernel_shape.size() == 2,
+                      "kernel_shape must be absent or have exactly 2 values (kH, kW) for 2D DeformConv.");
+    ORT_RETURN_IF_NOT(attrs.kernel_shape[0] == W_kH && attrs.kernel_shape[1] == W_kW,
+                      "kernel_shape must match weight spatial dimensions (W_shape[2], W_shape[3]).");
+    params.kH = attrs.kernel_shape[0];
+    params.kW = attrs.kernel_shape[1];
+  } else {
+    params.kH = W_kH;
+    params.kW = W_kW;
+  }
+
+  // DeformConv is 2D-only: when an attribute is present, require exact length to avoid silently misinterpreting malformed models.
+  params.pad_h = params.pad_w = params.pad_h_end = params.pad_w_end = 0;
+  if (!attrs.pads.empty()) {
+    ORT_RETURN_IF_NOT(attrs.pads.size() == 4,
+                      "pads must be absent or have exactly 4 values [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] for 2D DeformConv.");
+    params.pad_h = attrs.pads[0];
+    params.pad_w = attrs.pads[1];
+    params.pad_h_end = attrs.pads[2];
+    params.pad_w_end = attrs.pads[3];
+    ORT_RETURN_IF_NOT(params.pad_h >= 0 && params.pad_w >= 0 && params.pad_h_end >= 0 && params.pad_w_end >= 0,
+                      "Pads must be non-negative (ONNX spec).");
+  }
+
+  if (!attrs.strides.empty()) {
+    ORT_RETURN_IF_NOT(attrs.strides.size() == 2,
+                      "strides must be absent or have exactly 2 values [stride_h, stride_w] for 2D DeformConv.");
+    params.stride_h = attrs.strides[0];
+    params.stride_w = attrs.strides[1];
+  } else {
+    params.stride_h = params.stride_w = 1;
+  }
+
+  if (!attrs.dilations.empty()) {
+    ORT_RETURN_IF_NOT(attrs.dilations.size() == 2,
+                      "dilations must be absent or have exactly 2 values [dilation_h, dilation_w] for 2D DeformConv.");
+    params.dilation_h = attrs.dilations[0];
+    params.dilation_w = attrs.dilations[1];
+  } else {
+    params.dilation_h = params.dilation_w = 1;
+  }
+  params.group = attrs.group;
+  params.offset_group = attrs.offset_group;
+  params.use_mask = (mask_shape != nullptr);
+
+  // Validate attributes
+  ORT_RETURN_IF_NOT(params.stride_h > 0 && params.stride_w > 0, "Strides must be positive.");
+  ORT_RETURN_IF_NOT(params.dilation_h > 0 && params.dilation_w > 0, "Dilations must be positive.");
+  ORT_RETURN_IF_NOT(params.kH > 0 && params.kW > 0, "Kernel shape must be positive.");
+  ORT_RETURN_IF_NOT(params.group > 0, "group must be positive");
+  ORT_RETURN_IF_NOT(params.offset_group > 0, "offset_group must be positive");
+
+  // Standard conv output-size formula (floor division).
+  params.out_h = (params.H + params.pad_h + params.pad_h_end - params.dilation_h * (params.kH - 1) - 1) / params.stride_h + 1;
+  params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1;
+  ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative.");
+
+  // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math.
+  ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative.");
+  ORT_RETURN_IF_NOT(params.W_in <= static_cast<int64_t>(INT_MAX) / (params.H + 1),
+                    "Input (H+1)*W must not exceed INT_MAX (for performance optimization).");
+
+  // Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW).
+  ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size.");
+  const int64_t offset_block = 2 * params.kH * params.kW;
+  ORT_RETURN_IF_NOT(offset_block > 0 && offset_shape[1] % offset_block == 0 &&
+                        offset_shape[1] / offset_block == params.offset_group,
+                    "Offset channel count must be offset_group * 2 * kH * kW.");
+  ORT_RETURN_IF_NOT(offset_shape[2] == params.out_h, "Offset spatial height must match output oH.");
+  ORT_RETURN_IF_NOT(offset_shape[3] == params.out_w, "Offset spatial width must match output oW.");
+  ORT_RETURN_IF_NOT(params.C % params.offset_group == 0, "Input channels must be divisible by offset_group.");
+  ORT_RETURN_IF_NOT(params.C == W_shape[1] * params.group, "Input channels must match weight in channels * group.");
+  ORT_RETURN_IF_NOT(params.M % params.group == 0, "Output channels must be divisible by group.");
+
+  if (B_shape != nullptr) {
+    ORT_RETURN_IF_NOT(B_shape->NumDimensions() == 1, "Bias B must be 1D.");
+    ORT_RETURN_IF_NOT((*B_shape)[0] == params.M, "Bias B must have shape [M] (M = number of output channels).");
+  }
+
+  // Validate mask if present
+  if (params.use_mask) {
+    ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D.");
+    ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size.");
+    const int64_t mask_block = params.kH * params.kW;
+    ORT_RETURN_IF_NOT(mask_block > 0 && (*mask_shape)[1] % mask_block == 0 &&
+                          (*mask_shape)[1] / mask_block == params.offset_group,
+                      "Mask channel count must be offset_group * kH * kW.");
+    ORT_RETURN_IF_NOT((*mask_shape)[2] == params.out_h, "Mask spatial height must match output oH.");
+    ORT_RETURN_IF_NOT((*mask_shape)[3] == params.out_w, "Mask spatial width must match output oW.");
+  }
+
+  return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index b87cf8cbc16c1..4f36789191bb2 100755
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1455,6 +1455,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, DeformConv);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast);
@@ -1596,6 +1599,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Conv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, DeformConv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, DeformConv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, DeformConv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, DeformConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid);
@@ -2574,6 +2581,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2706,6 +2716,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
new file mode 100644
index 0000000000000..7a0b896acfe01
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
@@ -0,0 +1,331 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// CUDA implementation of DeformConv (deformable convolution 2D).
+
+#include "core/providers/shared_library/provider_api.h"
+#include "deform_conv.h"
+#include "deform_conv_impl.h"
+
+#include
+
+#include "core/common/narrow.h"
+#include "core/common/span_utils.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+namespace {
+
+constexpr int kMaxParallelImgs = 32;
+
+// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes.
+// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately.
+// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration
+// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision.
+int GetGreatestDivisorBelowBound(int n, int bound) {
+  if (bound <= 0 || n <= 0) return 1;
+  if (n % bound == 0) return bound;  // Fast path: batch is multiple of target
+
+  // n >= bound^2 <=> bound <= sqrt(n) => linear scan is cheaper
+  if (static_cast<int64_t>(n) >= static_cast<int64_t>(bound) * bound) {
+    for (int k = bound - 1; k > 1; --k) {
+      if (n % k == 0) return k;
+    }
+  } else {
+    // n < bound^2 <=> bound > sqrt(n) => divisor enumeration is cheaper
+    int best = 1;
+    for (int i = 1; static_cast<int64_t>(i) * i <= static_cast<int64_t>(n); ++i) {
+      if (n % i != 0) continue;
+      const int q = n / i;  // paired divisor of i
+      if (q <= bound && q > best) best = q;
+      if (i <= bound && i > best) best = i;
+    }
+    return best;
+  }
+  return 1;  // n is prime w.r.t. all k in [2, bound-1]
+}
+
+// Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers.
+// Uses a fraction of total GPU memory to avoid OOM while leaving room for weights, activations,
+// and other ops. No CUDA API is called; total_global_mem is expected from cached device props.
+//
+// Formula:
+// budget = total_global_mem * kFraction
+// return clamp(budget, kMin, kMax)
+// with kFraction = 0.1 (10%), kMin = 32 MiB, kMax = 2 GiB.
+//
+// Example results (effective_max_temp after clamp):
+// GPU | totalGlobalMem | effective_max_temp
+// -----------------|----------------|--------------------
+// A100 80GB | 80 GiB | 2 GiB (capped)
+// RTX 5080 16GB | 16 GiB | 1.6 GiB
+// RTX 4090 24GB | 24 GiB | 2 GiB (capped)
+// RTX 3080 10GB | 10 GiB | 1 GiB
+// GTX 1060 6GB | 6 GiB | 614.4 MiB
+// GTX 1050 4GB | 4 GiB | 409.6 MiB
+// Jetson 2GB | 2 GiB | 204.8 MiB
+size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) {
+  constexpr double kFraction = 0.1;                     // 10% of total device memory
+  constexpr size_t kMin = 32ULL * 1024 * 1024;          // 32 MiB floor
+  constexpr size_t kMax = 2ULL * 1024 * 1024 * 1024;    // 2 GiB cap
+  size_t budget = static_cast<size_t>(static_cast<double>(total_global_mem) * kFraction);
+  return std::clamp(budget, kMin, kMax);
+}
+
+// Returns how many images to process in parallel per batch chunk for DeformConv.
+// Chooses the largest divisor of batch size N that fits in the temp budget and does not
+// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder).
+// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1,
+// so batching is effectively disabled (single-image chunks).
+//
+// Formulas:
+// kernel_size = kH * kW
+// output_image_size = out_h * out_w
+// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T)
+// (temp bytes per image: im2col col buffer + GEMM output buffer per output position)
+// max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image))
+// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem)
+// return GetGreatestDivisorBelowBound(N, target_parallel_imgs)
+template <typename T>
+int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) {
+  const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem);
+  const int64_t kernel_size = params.kH * params.kW;
+  const int64_t output_image_size = params.out_h * params.out_w;
+  // Per-image temp bytes: im2col col buffer (C * kH * kW rows) + GEMM output (M / group rows).
+  const size_t bytes_per_image = SafeInt<size_t>(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T);
+  const int max_parallel_imgs_mem = std::max(1, static_cast<int>(effective_max_temp / std::max(size_t(1), bytes_per_image)));
+  const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem);
+  return GetGreatestDivisorBelowBound(static_cast<int>(params.N), target_parallel_imgs);
+}
+
+} // namespace
+
+template <typename T>
+Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
+  typedef typename ToCudaType<T>::MappedType CudaT;
+
+  const auto* X = context->Input<Tensor>(0);
+  const auto* W = context->Input<Tensor>(1);
+  const auto* offset = context->Input<Tensor>(2);
+  const auto* B = context->Input<Tensor>(3);       // optional bias
+  const auto* mask = context->Input<Tensor>(4);    // optional modulation mask
+
+  DeformConvParams params;
+  ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(
+      attrs_,
+      X->Shape(),
+      W->Shape(),
+      offset->Shape(),
+      B ? &B->Shape() : nullptr,
+      mask ? &mask->Shape() : nullptr,
+      params));
+
+  const int64_t N = params.N;
+  const int64_t C = params.C;
+  const int64_t H = params.H;
+  const int64_t W_in = params.W_in;
+  const int64_t M = params.M;
+  const int64_t kH = params.kH;
+  const int64_t kW = params.kW;
+  const int64_t pad_h = params.pad_h;
+  const int64_t pad_w = params.pad_w;
+  const int64_t stride_h = params.stride_h;
+  const int64_t stride_w = params.stride_w;
+  const int64_t dilation_h = params.dilation_h;
+  const int64_t dilation_w = params.dilation_w;
+  const int64_t group = params.group;
+  const int64_t offset_group = params.offset_group;
+  const int64_t out_h = params.out_h;
+  const int64_t out_w = params.out_w;
+  const bool use_mask = params.use_mask;
+
+  Tensor* Y = context->Output(0, {N, M, out_h, out_w});
+  if (Y->Shape().Size() == 0) {
+    return Status::OK();
+  }
+
+  const int n_parallel_imgs = GetNParallelImgs<T>(params, GetDeviceProp().totalGlobalMem);
+
+  const int64_t kernel_size = kH * kW;
+  const int64_t output_image_size = out_h * out_w;
+  const int64_t input_image_size = H * W_in;
+  const int64_t kernel_dim = (C / group) * kernel_size;
+
+  const int64_t col_stride = static_cast<int64_t>(n_parallel_imgs) * output_image_size;
+  const int64_t col_buffer_size = (C * kernel_size) * col_stride;
+
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+  auto col_buffer = IAllocator::MakeUniquePtr<T>(alloc, SafeInt<size_t>(col_buffer_size));
+  // Removed col_transposed allocation as we avoid physical transpose.
+  auto gemm_output_buffer = IAllocator::MakeUniquePtr<T>(alloc, SafeInt<size_t>((M / group) * col_stride));
+
+  const T* Xdata = X->Data<T>();
+  const T* Wdata = W->Data<T>();
+  const T* offset_data = offset->Data<T>();
+  const T* mask_data = use_mask ? mask->Data<T>() : nullptr;
+  T* Ydata = Y->MutableData<T>();
+  const T* Bdata = (B != nullptr) ? B->Data<T>() : nullptr;
+
+  cudaStream_t stream = Stream(context);
+  cublasHandle_t cublas = GetCublasHandle(context);
+  const cudaDeviceProp& device_prop = GetDeviceProp();
+  CudaT alpha = ToCudaType<T>::FromFloat(1.0f);
+  CudaT beta = ToCudaType<T>::FromFloat(0.0f);
+
+  for (int64_t b = 0; b < N; b += n_parallel_imgs) {
+    const int cur_parallel = static_cast<int>(std::min(static_cast<int64_t>(n_parallel_imgs), N - b));
+    const int64_t cur_out_size = static_cast<int64_t>(cur_parallel) * output_image_size;
+
+    const T* X_block = Xdata + b * (C * input_image_size);
+    const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size);
+    const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr;
+
+    ORT_RETURN_IF_ERROR(DeformConvIm2ColImpl(
+        stream,
+        X_block,
+        offset_block,
+        mask_block,
+        col_buffer.get(),
+        cur_parallel,
+        C,
+        H,
+        W_in,
+        kH,
+        kW,
+        out_h,
+        out_w,
+        pad_h,
+        pad_w,
+        stride_h,
+        stride_w,
+        dilation_h,
+        dilation_w,
+        offset_group,
+        use_mask));
+
+    // GEMM layout trick: compute Y = W * Col without physical transpose.
+    //
+    // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size].
+    // cuBLAS is column-major. Key insight: row-major A[M,K] in memory equals column-major A^T[K,M].
+    // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose):
+    //   - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T
+    //   - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T
+    //   - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major
+    //
+    // m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size.
+    //
+    // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write
+    // directly into Y_g. Use strided batched for all groups in one call.
+    // cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW.
+
+    const bool gemm_writes_directly = (cur_parallel == 1);
+    if (gemm_writes_directly) {
+      // Strided batched: one call for all groups. Strides between batches:
+      const int64_t stride_col = kernel_dim * col_stride;  // = kernel_dim * output_image_size when cur_parallel==1
+      const int64_t stride_weight = (M / group) * kernel_dim;
+      const int64_t stride_y = (M / group) * output_image_size;
+      CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
+          cublas,
+          CUBLAS_OP_N,
+          CUBLAS_OP_N,
+          narrow<int>(output_image_size),
+          narrow<int>(M / group),
+          narrow<int>(kernel_dim),
+          &alpha,
+          reinterpret_cast<const CudaT*>(col_buffer.get()),
+          narrow<int>(output_image_size),
+          stride_col,
+          reinterpret_cast<const CudaT*>(Wdata),
+          narrow<int>(kernel_dim),
+          stride_weight,
+          &beta,
+          reinterpret_cast<CudaT*>(Ydata + b * M * output_image_size),
+          narrow<int>(output_image_size),
+          stride_y,
+          narrow<int>(group),
+          device_prop,
+          UseTF32()));
+    } else {
+      // cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group.
+      for (int64_t g = 0; g < group; ++g) {
+        const T* W_g = Wdata + g * (M / group) * kernel_dim;
+        const T* col_g = col_buffer.get() + g * kernel_dim * col_stride;
+        T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size;
+
+        CUBLAS_RETURN_IF_ERROR((cublasGemmHelper(
+            cublas,
+            CUBLAS_OP_N,
+            CUBLAS_OP_N,
+            narrow<int>(cur_out_size),
+            narrow<int>(M / group),
+            narrow<int>(kernel_dim),
+            &alpha,
+            reinterpret_cast<const CudaT*>(col_g),
+            narrow<int>(cur_out_size),
+            reinterpret_cast<const CudaT*>(W_g),
+            narrow<int>(kernel_dim),
+            &beta,
+            reinterpret_cast<CudaT*>(gemm_output_buffer.get()),
+            narrow<int>(cur_out_size),
+            device_prop,
+            UseTF32())));
+
+        ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW(
+            stream,
+            gemm_output_buffer.get(),
+            Y_g,
+            M,
+            M / group,
+            output_image_size,
+            cur_parallel));
+      }
+    }
+  }
+
+  if (Bdata != nullptr) {
+    ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w));
+  }
+
+  return Status::OK();
+}
+
+#define REGISTER_DEFORMCONV_KERNEL_TYPED(T)                                                      \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                                                       \
+      DeformConv,                                                                                \
+      kOnnxDomain,                                                                               \
+      19,                                                                                        \
+      21,                                                                                        \
+      T,                                                                                         \
+      kCudaExecutionProvider,                                                                    \
+      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),       \
+      DeformConv<T>);                                                                            \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                                 \
+      DeformConv,                                                                                \
+      kOnnxDomain,                                                                               \
+      22,                                                                                        \
+      T,                                                                                         \
+      kCudaExecutionProvider,                                                                    \
+      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),       \
+      DeformConv<T>)
+
+REGISTER_DEFORMCONV_KERNEL_TYPED(float);
+REGISTER_DEFORMCONV_KERNEL_TYPED(double);
+REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16);
+
+// BFloat16 only for opset 22; opset 19-21 do not support BFloat16.
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    DeformConv,
+    kOnnxDomain,
+    22,
+    BFloat16,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<BFloat16>()),
+    DeformConv<BFloat16>);
+
+#undef REGISTER_DEFORMCONV_KERNEL_TYPED
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h
new file mode 100644
index 0000000000000..fa564641d4b98
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/nn/deform_conv_attributes.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+// CUDA kernel for the ONNX DeformConv operator (deformable 2D convolution).
+template <typename T>
+class DeformConv final : public CudaKernel {
+ public:
+  explicit DeformConv(const OpKernelInfo& info) : CudaKernel(info), attrs_(info) {}
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  DeformConvAttributes attrs_;  // node attributes parsed at kernel construction
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv);
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu
new file mode 100644
index 0000000000000..7b3666fca810b
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu
@@ -0,0 +1,512 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// CUDA implementation of DeformConv: deformable im2col kernel + bilinear interpolation.
+// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv spec.
+
+#include "deform_conv_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/shared_inc/fast_divmod.h"
+#include "core/common/float16.h"
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+
+namespace onnxruntime {
+namespace cuda {
+
+namespace {
+
+constexpr int kDeformConvThreadsPerBlock = 256;
+
+// Compile-time kernel-size tag used to dispatch fixed-size (unrolled) im2col variants.
+template <int N>
+struct DeformConvKSize {
+  static constexpr int value = N;
+};
+
+// Calculate grid size with a safety limit to prevent overflow.
+// Since we use grid-stride loops in kernels, limiting the grid size is safe.
+inline int GetGridSize(size_t n, size_t threads_per_block) {
+  size_t blocks_needed = (n + threads_per_block - 1) / threads_per_block;  // ceil-div
+  return static_cast<int>(std::min(blocks_needed, static_cast<size_t>(std::numeric_limits<int>::max())));
+}
+
+// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly.
+template <typename T>
+__device__ __inline__ T DeformConvLdg(const T* p) {
+  return __ldg(p);
+}
+template <>
+__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) {
+  return BFloat16::FromBits(__ldg(reinterpret_cast<const uint16_t*>(p)));
+}
+
+// Traits for bilinear interpolation math:
+// - ComputeT: type used for coordinate/weight math (float for half/BFloat16, T otherwise)
+// - Load: load one element and convert to ComputeT
+// - ToResult: convert ComputeT result back to T
+// - Zero: zero value of T
+template <typename T>
+struct DeformConvBilinearTraits {
+  using ComputeT = T;
+
+  __device__ static __inline__ ComputeT Load(const T* p) {
+    return __ldg(p);
+  }
+
+  __device__ static __inline__ T ToResult(ComputeT v) {
+    return v;
+  }
+
+  __device__ static __inline__ T Zero() {
+    return static_cast<T>(0);
+  }
+};
+
+template <>
+struct DeformConvBilinearTraits<half> {
+  using ComputeT = float;  // do coordinate/weight math in float for precision
+
+  __device__ static __inline__ ComputeT Load(const half* p) {
+    return __half2float(__ldg(p));
+  }
+
+  __device__ static __inline__ half ToResult(ComputeT v) {
+    return __float2half(v);
+  }
+
+  __device__ static __inline__ half Zero() {
+    return __float2half(0.0f);
+  }
+};
+
+template <>
+struct DeformConvBilinearTraits<BFloat16> {
+  using ComputeT = float;  // do coordinate/weight math in float for precision
+
+  __device__ static __inline__ ComputeT Load(const BFloat16* p) {
+    return static_cast<float>(DeformConvLdg(p));
+  }
+
+  __device__ static __inline__ BFloat16 ToResult(ComputeT v) {
+    return BFloat16(v);
+  }
+
+  __device__ static __inline__ BFloat16 Zero() {
+    return BFloat16(0.0f);
+  }
+};
+
+// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec).
+// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure and
+// improve occupancy in the hot path. Limitation: (H+1)*W must not exceed INT_MAX; this is
+// validated on the host side in DeformConvValidateAndParse to guarantee index math in int
+// does not overflow. For half/BFloat16, coordinate and weight math use float via
+// DeformConvBilinearTraits to avoid precision loss. We keep floor() results in CoordT and
+// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT
+// round trips when computing lh/lw/hh/hw.
+template <typename T>
+__device__ __inline__ T BilinearInterpolate(
+    const T* in,
+    int height,
+    int width,
+    typename DeformConvBilinearTraits<T>::ComputeT h,
+    typename DeformConvBilinearTraits<T>::ComputeT w) {
+  using Traits = DeformConvBilinearTraits<T>;
+  using CoordT = typename Traits::ComputeT;
+
+  // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case).
+  if (h <= static_cast<CoordT>(-1) || h >= height || w <= static_cast<CoordT>(-1) || w >= width) {
+    return Traits::Zero();
+  }
+
+  // [Optimization 2]: Keep floor result in CoordT; cast to int only for indices. Avoids float->int->float in lh/lw.
+  CoordT h_floor = _Floor(h);
+  CoordT w_floor = _Floor(w);
+  int h_low = static_cast<int>(h_floor);
+  int w_low = static_cast<int>(w_floor);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  CoordT lh = h - h_floor;
+  CoordT lw = w - w_floor;
+  CoordT hh = static_cast<CoordT>(1) - lh;
+  CoordT hw = static_cast<CoordT>(1) - lw;
+
+  // [Optimization 3]: Avoid a second multiply for base_high.
+  // Original code computed both bases as:
+  //   base_low = h_low * width;
+  //   base_high = h_high * width;
+  // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and
+  // save one integer multiply in the hot path:
+  //   base_low = h_low * width;
+  //   base_high = base_low + width;
+  int base_low = h_low * width;
+  int base_high = base_low + width;
+
+  // Per-corner loads guarded against the one-pixel border outside the image (contributes 0).
+  CoordT v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast<CoordT>(0);
+  CoordT v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast<CoordT>(0);
+  CoordT v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast<CoordT>(0);
+  CoordT v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast<CoordT>(0);
+
+  CoordT w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  return Traits::ToResult(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+}
+
+// kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling.
+template <typename T, typename IndexT, int kH, int kW>
+__global__ void DeformableIm2ColKernel(
+    IndexT num_kernels,
+    const T* __restrict__ input,
+    const T* __restrict__ offset,
+    const T* __restrict__ mask,
+    int64_t height,
+    int64_t width,
+    int64_t weight_h,
+    int64_t weight_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t channels,
+    int64_t offset_group,
+    DivMod<IndexT> out_h_div,
+    DivMod<IndexT> out_w_div,
+    DivMod<IndexT> parallel_imgs_div,
+    DivMod<IndexT> channel_per_offset_grp_div,
+    bool use_mask,
+    T* __restrict__ data_col) {
+  constexpr bool is_fixed = (kH >= 0 && kW >= 0);
+  const int64_t h_dim = is_fixed ? kH : weight_h;
+  const int64_t w_dim = is_fixed ? kW : weight_w;
+
+  // Reconstruct dimensions from DivMod objects
+  const int64_t out_h = out_h_div.d_;
+  const int64_t out_w = out_w_div.d_;
+  const int64_t parallel_imgs = parallel_imgs_div.d_;
+
+  const int64_t out_size = out_h * out_w;
+  // The stride for data_col is (parallel_imgs * out_h * out_w)
+  const int64_t col_stride = parallel_imgs * out_size;
+
+  using CoordT = typename DeformConvBilinearTraits<T>::ComputeT;
+
+  // Grid-stride loop; widen index math to IndexT so blockIdx.x * blockDim.x cannot
+  // overflow 32-bit unsigned when IndexT is int64_t.
+  for (IndexT index = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x; index < num_kernels;
+       index += static_cast<IndexT>(blockDim.x) * gridDim.x) {
+    IndexT val = index;
+    IndexT out_x, out_y, out_b, in_c;
+
+    // Fast division/modulo to recover coordinates
+    out_w_div.divmod(val, val, out_x);
+    out_h_div.divmod(val, val, out_y);
+    parallel_imgs_div.divmod(val, in_c, out_b);
+
+    // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case).
+    IndexT offset_grp = 0;
+    if (offset_group > 1) {
+      IndexT dummy;
+      channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy);
+    }
+
+    // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic
+    // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops.
+
+    // 1. Input pointer base for this batch and channel.
+    const T* input_ptr = input + static_cast<int64_t>(out_b) * (channels * height * width) + static_cast<int64_t>(in_c) * (height * width);
+
+    // 2. Spatial index in the output feature map.
+    const int64_t spatial_idx = static_cast<int64_t>(out_y) * out_w + static_cast<int64_t>(out_x);
+
+    // 3. Offset pointer base calculation.
+    // Layout: (N, offset_groups, 2*KH*KW, OH, OW)
+    // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx.
+    const int64_t offset_group_block_size = 2 * h_dim * w_dim * out_size;
+    const T* offset_ptr_base = offset + (static_cast<int64_t>(out_b) * offset_group + static_cast<int64_t>(offset_grp)) * offset_group_block_size + spatial_idx;
+
+    // 4. Mask pointer base calculation (if used).
+    // Layout: (N, offset_groups, KH*KW, OH, OW)
+    const T* mask_ptr_base = nullptr;
+    if (use_mask) {
+      const int64_t mask_group_block_size = h_dim * w_dim * out_size;
+      mask_ptr_base = mask + (static_cast<int64_t>(out_b) * offset_group + static_cast<int64_t>(offset_grp)) * mask_group_block_size + spatial_idx;
+    }
+
+    // 5. Output pointer base calculation.
+    // data_col Layout: (C * KH * KW, N * OH * OW)
+    // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx.
+    // The starting row for this channel is `in_c * KH * KW`.
+    const int64_t c_col = static_cast<int64_t>(out_b) * out_size + spatial_idx;
+    T* data_col_ptr_base = data_col + (static_cast<int64_t>(in_c) * h_dim * w_dim) * col_stride + c_col;
+
+    // 6. Pre-calculate invariant coordinate parts.
+    // Use float for coordinate math when T is half or BFloat16 to avoid precision loss.
+    const CoordT base_h_im = static_cast<CoordT>(out_y * stride_h - pad_h);
+    const CoordT base_w_im = static_cast<CoordT>(out_x * stride_w - pad_w);
+
+    auto process_kernel_point = [&](int64_t i, int64_t j) {
+      const int64_t kernel_idx = i * w_dim + j;
+      T mask_val = static_cast<T>(1);
+      if (use_mask) {
+        // Access mask using pre-calculated base and stride.
+        mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size);
+      }
+
+      // Calculate offset pointers relative to the base.
+      // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight.
+      // Stride between y_offset and x_offset is `out_size`.
+      const int64_t offset_offset_idx = (2 * kernel_idx) * out_size;
+
+      const CoordT offset_h = static_cast<CoordT>(DeformConvLdg(offset_ptr_base + offset_offset_idx));
+      const CoordT offset_w = static_cast<CoordT>(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size));
+
+      const CoordT h_im = base_h_im + static_cast<CoordT>(i * dilation_h) + offset_h;
+      const CoordT w_im = base_w_im + static_cast<CoordT>(j * dilation_w) + offset_w;
+
+      // height/width are validated on host (DeformConvValidateAndParse) so int is safe here.
+      T val = BilinearInterpolate(input_ptr,
+                                  static_cast<int>(height),
+                                  static_cast<int>(width),
+                                  h_im,
+                                  w_im);
+
+      // Match CPU path: always interpolate then apply mask to keep branch-free hot loop.
+      data_col_ptr_base[kernel_idx * col_stride] = val * mask_val;
+    };
+
+    if constexpr (is_fixed) {
+#pragma unroll
+      for (int i = 0; i < kH; ++i) {
+#pragma unroll
+        for (int j = 0; j < kW; ++j) {
+          process_kernel_point(i, j);
+        }
+      }
+    } else {
+      for (int64_t i = 0; i < weight_h; ++i) {
+        for (int64_t j = 0; j < weight_w; ++j) {
+          process_kernel_point(i, j);
+        }
+      }
+    }
+  }
+}
+
+// Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW.
+template <typename T>
+__global__ void DeformConvAddBiasKernel(
+    T* Y,
+    const T* B,
+    DivMod<int64_t> spatial_div,  // For dividing by (H * W)
+    DivMod<int64_t> channel_div,  // For dividing by M (channel count)
+    int64_t total_elements) {
+  // Grid-stride loop; index math widened to int64_t to avoid 32-bit overflow.
+  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < total_elements;
+       idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
+    int64_t val = idx;
+    int64_t batch_channel_idx, pixel_idx;
+
+    // 1. First decomposition: decompose idx into (batch_channel_idx, pixel_idx)
+    //    Equivalent to: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W);
+    spatial_div.divmod(val, batch_channel_idx, pixel_idx);
+
+    int64_t batch_idx, channel_idx;
+
+    // 2. Second decomposition: decompose batch_channel_idx into (batch_idx, channel_idx)
+    //    Equivalent to: channel_idx = batch_channel_idx % M;
+    //    We only need channel_idx (i.e. m)
+    channel_div.divmod(batch_channel_idx, batch_idx, channel_idx);
+    (void)batch_idx;  // Only channel_idx is needed
+
+    // channel_idx is what we need (i.e. m)
+    Y[idx] += DeformConvLdg(B + channel_idx);
+  }
+}
+
+// Copy GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) into NCHW Y_g.
+// src(c, j) with j = b_idx*output_image_size + pos -> dst[b_idx*M*output_image_size + c*output_image_size + pos].
+template <typename T>
+__global__ void CopyGemmOutputRowMajorToNCHWKernel(
+    const T* __restrict__ src,
+    T* __restrict__ dst,
+    int64_t M,
+    int64_t M_per_group,
+    int64_t output_image_size,
+    int64_t cur_parallel) {
+  int64_t total = cur_parallel * M_per_group * output_image_size;
+  // Grid-stride loop; index math widened to int64_t to avoid 32-bit overflow.
+  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < total;
+       idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
+    int64_t pos = idx % output_image_size;
+    int64_t c = (idx / output_image_size) % M_per_group;
+    int64_t b_idx = idx / (output_image_size * M_per_group);
+    int64_t j = b_idx * output_image_size + pos;
+    // src index for row-major: c * (cur_parallel * output_image_size) + j
+    dst[b_idx * M * output_image_size + c * output_image_size + pos] = src[c * (cur_parallel * output_image_size) + j];
+  }
+}
+
+} // namespace
+
+template <typename T>
+Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) {
+  int64_t total = N * M * out_h * out_w;
+  if (total <= 0) return Status::OK();
+
+  // 1. Prepare divisor
+  int64_t out_size = out_h * out_w;
+
+  // 2. Create FastDivMod objects (int64_t variant to match the element count type)
+  DivMod<int64_t> spatial_div(out_size);
+  DivMod<int64_t> channel_div(M);
+
+  int blocks = GetGridSize(static_cast<size_t>(total), kDeformConvThreadsPerBlock);
+
+  // 3. Launch on the provided stream; grid-stride loop covers any tail.
+  DeformConvAddBiasKernel<T><<<blocks, kDeformConvThreadsPerBlock, 0, stream>>>(
+      Y,
+      B,
+      spatial_div,
+      channel_div,
+      total);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template <typename T>
+Status DeformConvCopyGemmOutputRowMajorToNCHW(
+    cudaStream_t stream,
+    const T* gemm_output,
+    T* Y_g,
+    int64_t M,
+    int64_t M_per_group,
+    int64_t output_image_size,
+    int64_t cur_parallel) {
+  int64_t total = cur_parallel * M_per_group * output_image_size;
+  if (total <= 0) return Status::OK();
+  int blocks = GetGridSize(static_cast<size_t>(total), kDeformConvThreadsPerBlock);
+  CopyGemmOutputRowMajorToNCHWKernel<T><<<blocks, kDeformConvThreadsPerBlock, 0, stream>>>(
+      gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template
+Status DeformConvIm2ColImpl(
+ cudaStream_t stream,
+ const T* input,
+ const T* offset,
+ const T* mask,
+ T* col_buffer,
+ int64_t parallel_imgs,
+ int64_t C,
+ int64_t H,
+ int64_t W,
+ int64_t kH,
+ int64_t kW,
+ int64_t out_h,
+ int64_t out_w,
+ int64_t pad_h,
+ int64_t pad_w,
+ int64_t stride_h,
+ int64_t stride_w,
+ int64_t dilation_h,
+ int64_t dilation_w,
+ int64_t offset_group,
+ bool use_mask) {
+ const int64_t num_kernels = static_cast(C) * out_h * out_w * parallel_imgs;
+ if (num_kernels <= 0) {
+ return Status::OK();
+ }
+
+ const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w;
+ const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) ||
+ (col_numel > static_cast(std::numeric_limits::max()));
+
+ int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock);
+
+ auto launch = [&](auto kH_tag, auto kW_tag) {
+ constexpr int KH = decltype(kH_tag)::value;
+ constexpr int KW = decltype(kW_tag)::value;
+ if (use_64bit) {
+ DeformableIm2ColKernel<<>>(
+ num_kernels, input, offset, mask, H, W, kH, kW, pad_h, pad_w,
+ stride_h, stride_w, dilation_h, dilation_w, C, offset_group,
+ DivMod(out_h), DivMod(out_w), DivMod(parallel_imgs),
+ DivMod(C / offset_group), use_mask, col_buffer);
+ } else {
+ DeformableIm2ColKernel<<>>(
+ static_cast(num_kernels), input, offset, mask, H, W, kH, kW, pad_h, pad_w,
+ stride_h, stride_w, dilation_h, dilation_w, C, offset_group,
+ DivMod(static_cast(out_h)),
+ DivMod(static_cast(out_w)),
+ DivMod