diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 39c9145a40912..625cc4e09ca13 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -103,6 +103,8 @@ Do not modify directly.* |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DeformConv|*in* X:**T**
*in* W:**T**
*in* offset:**T**
*in* B:**T**
*in* mask:**T**
*out* Y:**T**|22+|**T** = tensor(double), tensor(float)| +|||[19, 21]|**T** = tensor(double), tensor(float)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -697,6 +699,8 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| +|DeformConv|*in* X:**T**
*in* W:**T**
*in* offset:**T**
*in* B:**T**
*in* mask:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -843,7 +847,12 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|25+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||24|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||23|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -902,7 +911,9 @@ Do not modify directly.* |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|22+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 74b8f8e468097..9f19a20a2e680 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1220,6 +1220,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); // Opset 20 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, 20, ConstantOfShape); @@ -1316,6 +1318,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Ac class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Atanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, DeformConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Det); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_float, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, 
float_double, Dropout); @@ -3277,6 +3281,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Resize)>, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
new file mode 100644
index 0000000000000..f128b0e0182ad
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
@@ -0,0 +1,321 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// CPU implementation of DeformConv (deformable convolution 2D).
+
+#include "deform_conv.h"
+
+#include <cmath>
+
+#include "core/common/common.h"
+#include "core/util/math_cpuonly.h"
+#include "core/common/narrow.h"
+#include "core/util/math.h"
+
+namespace onnxruntime {
+
+namespace {
+// Bilinear interpolation of row-major `in` (height x width) at (h, w). Out-of-bounds samples return 0 (ONNX spec); NaN coordinates are treated as out-of-bounds.
+// Indices use int (not int64_t) to reduce register pressure and improve occupancy in the hot path.
+// Limitation: height and width must not exceed INT_MAX, or casting floor(h)/floor(w) to int may overflow.
+// Acceptable in practice: deformable convolution spatial dimensions are typically well below INT_MAX.
+template <typename T>
+T BilinearInterpolate(const T* in, int height, int width, T h, T w) {
+  // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). The negated in-range test also rejects NaN, which would otherwise reach the float->int casts below (undefined behaviour).
+  if (!(h > static_cast<T>(-1) && h < height && w > static_cast<T>(-1) && w < width)) {
+    return static_cast<T>(0);
+  }
+
+  // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw.
+  const T h_floor = std::floor(h);
+  const T w_floor = std::floor(w);
+  const int h_low = static_cast<int>(h_floor);
+  const int w_low = static_cast<int>(w_floor);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const T lh = h - h_floor;
+  const T lw = w - w_floor;
+  const T hh = static_cast<T>(1) - lh;
+  const T hw = static_cast<T>(1) - lw;
+
+  // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)).
+  // Most sampling points in deformable conv fall here; avoids 4 per-corner branches.
+  // [Optimization 3]: Use unsigned comparison to avoid branch on negative height/width.
+  if (static_cast<unsigned>(h_low) < static_cast<unsigned>(height - 1) &&
+      static_cast<unsigned>(w_low) < static_cast<unsigned>(width - 1)) {
+    const int base_low = h_low * width;
+    const int base_high = h_high * width;
+    return hh * hw * in[base_low + w_low] +
+           hh * lw * in[base_low + w_high] +
+           lh * hw * in[base_high + w_low] +
+           lh * lw * in[base_high + w_high];
+  }
+
+  // Slow path: near boundary (one or more of the 4 corners may be out of bounds).
+  const int base_low = h_low * width;
+  const int base_high = h_high * width;
+  const T v1 = (h_low >= 0 && w_low >= 0) ? in[base_low + w_low] : static_cast<T>(0);
+  const T v2 = (h_low >= 0 && w_high < width) ? in[base_low + w_high] : static_cast<T>(0);
+  const T v3 = (h_high < height && w_low >= 0) ? in[base_high + w_low] : static_cast<T>(0);
+  const T v4 = (h_high < height && w_high < width) ? in[base_high + w_high] : static_cast<T>(0);
+  return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
+}
+
+// Deformable Im2Col for a SINGLE image.
+// Converts the input image into a matrix suitable for GEMM by sampling with learned offsets.
+// Output 'data_col' shape: [C_in * kH * kW, H_out * W_out]
+// When UseMask=false, pass nullptr for data_mask; compiler eliminates dead code for mask.
+template <typename T, bool UseMask>
+void DeformableIm2col(
+    const T* data_im,                        // Input image [C, H, W]
+    const T* data_offset,                    // Offset [offset_groups * 2 * kH * kW, H_out, W_out]
+    const T* data_mask,                      // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false)
+    int height, int width,                   // Input spatial dimensions (validated (H+1)*W <= INT_MAX)
+    int64_t kernel_h, int64_t kernel_w,      // Kernel dimensions
+    int64_t pad_h, int64_t pad_w,            // Padding (begin) for H and W
+    int64_t stride_h, int64_t stride_w,      // Stride for H and W
+    int64_t dilation_h, int64_t dilation_w,  // Dilation for H and W
+    int64_t channels,                        // Input channels
+    int64_t offset_groups,                   // Number of offset groups (channels shared per group)
+    int64_t height_col, int64_t width_col,   // Output spatial dimensions (H_out, W_out)
+    T* data_col,                             // Output buffer for im2col result
+    concurrency::ThreadPool* thread_pool) {
+  const int64_t channel_per_offset_group = channels / offset_groups;
+  const int64_t kernel_size = kernel_h * kernel_w;
+  const int64_t output_size = height_col * width_col;
+
+  // Parallelize over (channel, kernel_position) so each task processes one full row of data_col.
+  // This yields channels*kernel_size tasks, better CPU utilization and cache-friendly sequential writes.
+  concurrency::ThreadPool::TryParallelFor(
+      thread_pool,
+      static_cast<ptrdiff_t>(channels * kernel_size),
+      static_cast<double>(output_size) * 10.0,
+      [&](ptrdiff_t begin, ptrdiff_t end) {
+        for (ptrdiff_t idx = begin; idx < end; ++idx) {
+          // Decompose idx into (c_im, i, j): which channel and kernel position.
+          const int64_t j = static_cast<int64_t>(idx) % kernel_w;
+          const int64_t i = (static_cast<int64_t>(idx) / kernel_w) % kernel_h;
+          const int64_t c_im = static_cast<int64_t>(idx) / kernel_size;
+          const int64_t offset_grp = c_im / channel_per_offset_group;
+
+          // Output row: one (channel, kernel_pos) across all spatial locations.
+          T* col_ptr = data_col + static_cast<int64_t>(idx) * output_size;
+          const T* im_ptr = data_im + c_im * static_cast<int64_t>(height) * width;
+
+          // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened.
+          // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w.
+          // Precompute pointers to avoid offset_base * output_size multiplication in inner loop.
+          const int64_t offset_base =
+              offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j);
+          const T* ptr_offset_h = data_offset + offset_base * output_size;
+          const T* ptr_offset_w = data_offset + (offset_base + 1) * output_size;
+
+          // Base terms for h_im, w_im: invariant in inner loop (i, j fixed).
+          const T base_h = -pad_h + static_cast<T>(i) * dilation_h;
+          const T base_w = -pad_w + static_cast<T>(j) * dilation_w;
+
+          // Mask pointer; only used when UseMask=true (compiler removes when false).
+          [[maybe_unused]] const T* ptr_mask = nullptr;
+          if constexpr (UseMask) {
+            ptr_mask = data_mask + (offset_grp * kernel_size + i * kernel_w + j) * output_size;
+          }
+
+          // Loop over output spatial positions.
+          for (int64_t h_col = 0; h_col < height_col; ++h_col) {
+            for (int64_t w_col = 0; w_col < width_col; ++w_col) {
+              const int64_t spatial_idx = h_col * width_col + w_col;
+
+              const T offset_h = ptr_offset_h[spatial_idx];
+              const T offset_w = ptr_offset_w[spatial_idx];
+
+              // Deformed sampling coordinates (fractional, for bilinear interpolation).
+              const T h_im = h_col * stride_h + base_h + offset_h;
+              const T w_im = w_col * stride_w + base_w + offset_w;
+
+              // Sample input at deformed location; returns 0 if out of bounds.
+              T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im);
+
+              // Modulate by mask when UseMask=true; compiled away when false.
+              // Design choice: we always interpolate then multiply, rather than skip when mask==0.
+              // Rationale: (1) Skipping adds a branch; unpredictable mask values cause misprediction
+              // penalties (~15-20 cycles). (2) Straight-line code vectorizes better; conditional
+              // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN
+              // usage (moderate mask density), the unconditional path usually wins.
+              if constexpr (UseMask) {
+                val *= ptr_mask[spatial_idx];
+              }
+
+              col_ptr[spatial_idx] = val;
+            }
+          }
+        }
+      });
+}
+
+}  // namespace
+
+template <typename T>
+Status DeformConv<T>::Compute(OpKernelContext* context) const {
+  const auto* X = context->Input<Tensor>(0);
+  const auto* W = context->Input<Tensor>(1);
+  const auto* offset = context->Input<Tensor>(2);
+  const auto* B = context->Input<Tensor>(3);     // optional
+  const auto* mask = context->Input<Tensor>(4);  // optional
+
+  DeformConvParams params;
+  ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(
+      attrs_,
+      X->Shape(),
+      W->Shape(),
+      offset->Shape(),
+      B ? &B->Shape() : nullptr,
+      mask ? &mask->Shape() : nullptr,
+      params));
+
+  const int64_t N = params.N;
+  const int64_t C = params.C;
+  const int64_t H = params.H;
+  const int64_t W_in = params.W_in;
+  const int64_t M = params.M;
+  const int64_t kH = params.kH;
+  const int64_t kW = params.kW;
+  const int64_t pad_h = params.pad_h;
+  const int64_t pad_w = params.pad_w;
+  const int64_t stride_h = params.stride_h;
+  const int64_t stride_w = params.stride_w;
+  const int64_t dilation_h = params.dilation_h;
+  const int64_t dilation_w = params.dilation_w;
+  const int64_t group = params.group;
+  const int64_t offset_group = params.offset_group;
+  const int64_t out_h = params.out_h;
+  const int64_t out_w = params.out_w;
+  const bool use_mask = params.use_mask;
+
+  // Allocate output tensor [N, M, out_h, out_w].
+  const TensorShape Y_shape({N, M, out_h, out_w});
+  Tensor* Y = context->Output(0, Y_shape);
+  if (Y->Shape().Size() == 0) {
+    return Status::OK();
+  }
+
+  // Precompute common sizes for the im2col + GEMM pipeline.
+  const int64_t kernel_size = kH * kW;
+  const int64_t output_image_size = out_h * out_w;
+  const int64_t input_image_size = H * W_in;
+  const int64_t kernel_dim = C / group * kernel_size;  // K dimension for GEMM: C/group * kH * kW
+
+  // Col buffer: shape [C*kH*kW, out_h*out_w]. Allocate per-image (process one image at a time)
+  // to reduce peak memory when N is large; im2col is implemented per-image anyway.
+  const int64_t col_buffer_size = (C * kernel_size) * output_image_size;
+
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+  auto col_buffer = IAllocator::MakeUniquePtr<T>(alloc, SafeInt<size_t>(col_buffer_size));
+
+  const T* Xdata = X->Data<T>();
+  const T* Wdata = W->Data<T>();
+  const T* offset_data = offset->Data<T>();
+  const T* mask_data = use_mask ? mask->Data<T>() : nullptr;
+  T* Ydata = Y->MutableData<T>();
+  const T* Bdata = (B != nullptr) ? B->Data<T>() : nullptr;
+
+  concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+  // Process each image in the batch.
+  for (int64_t n = 0; n < N; ++n) {
+    // Step 1: Deformable Im2Col for image n.
+    // Gather deformed samples into col buffer for GEMM.
+    const T* X_curr = Xdata + n * (C * input_image_size);
+    const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size);
+    const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr;
+    T* col_buffer_ptr = col_buffer.get();
+
+    // Dispatch to template instantiation: UseMask=true or false eliminates branch in hot loop.
+    // Note: pad_h, pad_w are begin-side paddings for coordinate mapping; pad_h_end/pad_w_end
+    // affect only output size (already baked into out_h, out_w), not im2col sampling.
+    if (use_mask) {
+      DeformableIm2col<T, true>(
+          X_curr, offset_curr, mask_curr,
+          static_cast<int>(H), static_cast<int>(W_in), kH, kW,
+          pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+          C, offset_group, out_h, out_w,
+          col_buffer_ptr, thread_pool);
+    } else {
+      DeformableIm2col<T, false>(
+          X_curr, offset_curr, nullptr,
+          static_cast<int>(H), static_cast<int>(W_in), kH, kW,
+          pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+          C, offset_group, out_h, out_w,
+          col_buffer_ptr, thread_pool);
+    }
+
+    // Step 2: GEMM for each group. Y = W * Col (per group).
+    for (int64_t g = 0; g < group; ++g) {
+      // Weight for group g: shape [M/group, C/group, kH, kW], row-major.
+      const T* weight_g = Wdata + g * (M / group) * kernel_dim;
+
+      // Col rows for group g: layout [C*kH*kW, out_h*out_w], group g spans rows [g*kernel_dim, (g+1)*kernel_dim).
+      const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size;
+
+      // Output slice for group g: [n, g*M/group:(g+1)*M/group, out_h, out_w].
+      T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size;
+
+      // GEMM: Y = W * Col. W [M/group, kernel_dim], Col [kernel_dim, output_image_size].
+      math::Gemm<T>(
+          CblasNoTrans,
+          CblasNoTrans,
+          narrow<ptrdiff_t>(M / group),          // M
+          narrow<ptrdiff_t>(output_image_size),  // N
+          narrow<ptrdiff_t>(kernel_dim),         // K
+          static_cast<T>(1),                     // alpha
+          weight_g,                              // A
+          col_g,                                 // B
+          static_cast<T>(0),                     // beta
+          Y_g,                                   // C
+          thread_pool,
+          nullptr);  // mlas_backend_kernel_selector_config
+    }
+  }
+
+  // Step 3: Add bias if provided (broadcast over spatial dimensions).
+  if (Bdata != nullptr) {
+    int64_t total_work = N * M;
+    concurrency::ThreadPool::TryParallelFor(
+        thread_pool, static_cast<ptrdiff_t>(total_work), static_cast<double>(output_image_size),
+        [&](ptrdiff_t first, ptrdiff_t last) {
+          for (ptrdiff_t idx = first; idx < last; ++idx) {
+            int64_t n = idx / M;
+            int64_t m = idx % M;
+            T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size;
+            // Eigen vectorized add: Y_ptr += Bdata[m] over all spatial positions.
+            EigenVectorArrayMap<T>(Y_ptr, narrow<ptrdiff_t>(output_image_size)) += Bdata[m];
+          }
+        });
+  }
+
+  return Status::OK();
+}
+
+// Explicit template instantiation for float and double
+template class DeformConv<float>;
+template class DeformConv<double>;
+
+#define REGISTER_DEFORMCONV_KERNEL_TYPED(T)                   \
+  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(                   \
+      DeformConv, 19, 21, T,                                  \
+      KernelDefBuilder()                                      \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      DeformConv<T>);                                         \
+  ONNX_CPU_OPERATOR_TYPED_KERNEL(                             \
+      DeformConv, 22, T,                                      \
+      KernelDefBuilder()                                      \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      DeformConv<T>)
+
+REGISTER_DEFORMCONV_KERNEL_TYPED(float)
+REGISTER_DEFORMCONV_KERNEL_TYPED(double)
+
+#undef REGISTER_DEFORMCONV_KERNEL_TYPED
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h
new file mode 100644
index 0000000000000..c8d7763e58bcb
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/op_node_proto_helper.h"
+#include "deform_conv_attributes.h"
+
+namespace onnxruntime {
+
+template <typename T>
+class DeformConv : public OpKernel {
+ public:
+  explicit DeformConv(const OpKernelInfo& info) : OpKernel(info), attrs_(info) {}
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  DeformConvAttributes attrs_;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
new file mode 100644
index 0000000000000..8bc891bb4f377
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
@@ -0,0 +1,198 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <climits>
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/tensor_shape.h"
+
+namespace onnxruntime {
+
+// Shared attributes for ONNX DeformConv (opset 19+).
+// See https://onnx.ai/onnx/operators/onnx__DeformConv.html
+// Used by both CPU and CUDA implementations (CUDA includes from here).
+struct DeformConvAttributes {
+  explicit DeformConvAttributes(const OpKernelInfo& info) {
+    // Optional attributes.
+    // If not present, they will be empty/default, and handled in Compute/ComputeInternal.
+    (void)info.GetAttrs("kernel_shape", kernel_shape);
+    (void)info.GetAttrs("strides", strides);
+    (void)info.GetAttrs("pads", pads);
+    (void)info.GetAttrs("dilations", dilations);
+    group = info.GetAttrOrDefault<int64_t>("group", 1);
+    offset_group = info.GetAttrOrDefault<int64_t>("offset_group", 1);
+  }
+
+  TensorShapeVector kernel_shape;
+  TensorShapeVector strides;
+  TensorShapeVector pads;
+  TensorShapeVector dilations;
+  int64_t group{1};
+  int64_t offset_group{1};
+};
+
+// Parsed and validated parameters from DeformConv inputs.
+// Used by both CPU and CUDA implementations.
+// Field names align with ONNX DeformConv spec: https://onnx.ai/onnx/operators/onnx__DeformConv.html
+struct DeformConvParams {
+  // Input X shape (N, C, H, W)
+  int64_t N{0};     // Batch size
+  int64_t C{0};     // Number of input channels
+  int64_t H{0};     // Input height
+  int64_t W_in{0};  // Input width (W_in to avoid collision with weight W)
+
+  // Weight W shape (oC, C/group, kH, kW)
+  int64_t M{0};   // Number of output channels (oC)
+  int64_t kH{0};  // Kernel height
+  int64_t kW{0};  // Kernel width
+
+  // Pads [x1_begin, x2_begin, x1_end, x2_end] for spatial axes H, W
+  int64_t pad_h{0};
+  int64_t pad_w{0};
+  int64_t pad_h_end{0};
+  int64_t pad_w_end{0};
+
+  // Strides and dilations along each spatial axis (default 1)
+  int64_t stride_h{1};
+  int64_t stride_w{1};
+  int64_t dilation_h{1};
+  int64_t dilation_w{1};
+
+  // Attributes: C and oC must be divisible by group; C must be divisible by offset_group
+  int64_t group{1};         // Number of groups for input/output channels
+  int64_t offset_group{1};  // Number of groups of offset
+
+  // Output Y shape (N, oC, oH, oW)
+  int64_t out_h{0};  // Output height (oH)
+  int64_t out_w{0};  // Output width (oW)
+
+  bool use_mask{false};  // Whether optional mask input is provided
+};
+
+// Validates inputs and parses attributes into params.
+// Returns Status::OK() on success; on failure, params may be partially filled.
+inline Status DeformConvValidateAndParse(
+    const DeformConvAttributes& attrs,
+    const TensorShape& X_shape,
+    const TensorShape& W_shape,
+    const TensorShape& offset_shape,
+    const TensorShape* B_shape,
+    const TensorShape* mask_shape,
+    DeformConvParams& params) {
+  ORT_RETURN_IF_NOT(X_shape.NumDimensions() == 4, "Input X must be 4D (N, C, H, W).");
+  ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D.");
+  ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D.");
+
+  // Parse input shapes
+  params.N = X_shape[0];
+  params.C = X_shape[1];
+  params.H = X_shape[2];
+  params.W_in = X_shape[3];
+  params.M = W_shape[0];
+  ORT_RETURN_IF_NOT(params.N >= 0, "Batch size N must be non-negative.");
+  ORT_RETURN_IF_NOT(params.C > 0, "Input channels C must be positive.");
+  ORT_RETURN_IF_NOT(params.M > 0, "Output channels M (oC) must be positive.");
+  ORT_RETURN_IF_NOT(W_shape[1] > 0, "Weight W must have positive in-channels (W_shape[1] = C/group).");
+
+  // Handle kernel shape inference. If kernel_shape is provided, it must match weight spatial dims
+  // to avoid GEMM using wrong K and potential out-of-bounds reads from the weight buffer.
+  const int64_t W_kH = W_shape[2];
+  const int64_t W_kW = W_shape[3];
+  if (!attrs.kernel_shape.empty()) {
+    ORT_RETURN_IF_NOT(attrs.kernel_shape.size() == 2,
+                      "kernel_shape must be absent or have exactly 2 values (kH, kW) for 2D DeformConv.");
+    ORT_RETURN_IF_NOT(attrs.kernel_shape[0] == W_kH && attrs.kernel_shape[1] == W_kW,
+                      "kernel_shape must match weight spatial dimensions (W_shape[2], W_shape[3]).");
+    params.kH = attrs.kernel_shape[0];
+    params.kW = attrs.kernel_shape[1];
+  } else {
+    params.kH = W_kH;
+    params.kW = W_kW;
+  }
+
+  // DeformConv is 2D-only: when an attribute is present, require exact length to avoid silently misinterpreting malformed models.
+  params.pad_h = params.pad_w = params.pad_h_end = params.pad_w_end = 0;
+  if (!attrs.pads.empty()) {
+    ORT_RETURN_IF_NOT(attrs.pads.size() == 4,
+                      "pads must be absent or have exactly 4 values [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] for 2D DeformConv.");
+    params.pad_h = attrs.pads[0];
+    params.pad_w = attrs.pads[1];
+    params.pad_h_end = attrs.pads[2];
+    params.pad_w_end = attrs.pads[3];
+    ORT_RETURN_IF_NOT(params.pad_h >= 0 && params.pad_w >= 0 && params.pad_h_end >= 0 && params.pad_w_end >= 0,
+                      "Pads must be non-negative (ONNX spec).");
+  }
+
+  if (!attrs.strides.empty()) {
+    ORT_RETURN_IF_NOT(attrs.strides.size() == 2,
+                      "strides must be absent or have exactly 2 values [stride_h, stride_w] for 2D DeformConv.");
+    params.stride_h = attrs.strides[0];
+    params.stride_w = attrs.strides[1];
+  } else {
+    params.stride_h = params.stride_w = 1;
+  }
+
+  if (!attrs.dilations.empty()) {
+    ORT_RETURN_IF_NOT(attrs.dilations.size() == 2,
+                      "dilations must be absent or have exactly 2 values [dilation_h, dilation_w] for 2D DeformConv.");
+    params.dilation_h = attrs.dilations[0];
+    params.dilation_w = attrs.dilations[1];
+  } else {
+    params.dilation_h = params.dilation_w = 1;
+  }
+  params.group = attrs.group;
+  params.offset_group = attrs.offset_group;
+  params.use_mask = (mask_shape != nullptr);
+
+  // Validate attributes
+  ORT_RETURN_IF_NOT(params.stride_h > 0 && params.stride_w > 0, "Strides must be positive.");
+  ORT_RETURN_IF_NOT(params.dilation_h > 0 && params.dilation_w > 0, "Dilations must be positive.");
+  ORT_RETURN_IF_NOT(params.kH > 0 && params.kW > 0, "Kernel shape must be positive.");
+  ORT_RETURN_IF_NOT(params.group > 0, "group must be positive");
+  ORT_RETURN_IF_NOT(params.offset_group > 0, "offset_group must be positive");
+  // oH = span / stride + 1, written as (span + stride) / stride: same result, except a span in (-stride, 0) (kernel larger than padded input) now yields 0 instead of truncating up to 1.
+  params.out_h = (params.H + params.pad_h + params.pad_h_end - params.dilation_h * (params.kH - 1) - 1 + params.stride_h) / params.stride_h;
+  params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1 + params.stride_w) / params.stride_w;
+  ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative.");
+
+  // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math.
+  ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative.");
+  ORT_RETURN_IF_NOT(params.W_in <= static_cast<int64_t>(INT_MAX) / (params.H + 1),
+                    "Input (H+1)*W must not exceed INT_MAX (for performance optimization).");
+
+  // Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW).
+  ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size.");
+  const int64_t offset_block = 2 * params.kH * params.kW;
+  ORT_RETURN_IF_NOT(offset_block > 0 && offset_shape[1] % offset_block == 0 &&
+                        offset_shape[1] / offset_block == params.offset_group,
+                    "Offset channel count must be offset_group * 2 * kH * kW.");
+  ORT_RETURN_IF_NOT(offset_shape[2] == params.out_h, "Offset spatial height must match output oH.");
+  ORT_RETURN_IF_NOT(offset_shape[3] == params.out_w, "Offset spatial width must match output oW.");
+  ORT_RETURN_IF_NOT(params.C % params.offset_group == 0, "Input channels must be divisible by offset_group.");
+  ORT_RETURN_IF_NOT(params.C == W_shape[1] * params.group, "Input channels must match weight in channels * group.");
+  ORT_RETURN_IF_NOT(params.M % params.group == 0, "Output channels must be divisible by group.");
+
+  if (B_shape != nullptr) {
+    ORT_RETURN_IF_NOT(B_shape->NumDimensions() == 1, "Bias B must be 1D.");
+    ORT_RETURN_IF_NOT((*B_shape)[0] == params.M, "Bias B must have shape [M] (M = number of output channels).");
+  }
+
+  // Validate mask if present
+  if (params.use_mask) {
+    ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D.");
+    ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size.");
+    const int64_t mask_block = params.kH * params.kW;
+    ORT_RETURN_IF_NOT(mask_block > 0 && (*mask_shape)[1] % mask_block == 0 &&
+                          (*mask_shape)[1] / mask_block == params.offset_group,
+                      "Mask channel count must be offset_group * kH * kW.");
+    ORT_RETURN_IF_NOT((*mask_shape)[2] == params.out_h, "Mask spatial height must match output oH.");
+    ORT_RETURN_IF_NOT((*mask_shape)[3] == params.out_w, "Mask spatial width must match output oW.");
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h
index bb97de158369b..4ce4825e1d78c 100644
--- a/onnxruntime/core/providers/cpu/object_detection/roialign.h
+++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h
@@ -129,6 +129,10 @@ class RoiAlignBase {
     std::string coordinate_transformation_mode;
     if (info.template GetAttr<std::string>("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) {
       half_pixel_ = coordinate_transformation_mode == "half_pixel";
+    } else {
+      // For opset 16+, the default is "half_pixel" per ONNX spec.
+      // For opset 10 (which has no coordinate_transformation_mode attribute), false is correct.
+ half_pixel_ = info.node().SinceVersion() >= 16; } if (mode_ == RoiAlignMode::max && sampling_ratio_ != 1) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b87cf8cbc16c1..4c735fa2d5650 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -944,8 +944,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu); @@ 
-1441,10 +1444,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize); @@ -1452,9 +1455,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize); // Opset 19 +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Pad); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast); @@ -1573,6 +1583,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Pad); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, bool, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Identity); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, If); @@ -1596,6 +1610,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, DeformConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid); @@ -1608,6 +1626,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 
22, float, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, GRU); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); // Opset 23. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -1639,6 +1661,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); #endif +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Transpose); @@ -1663,10 +1689,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 24, BFloat16, Attention); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, bool, Pad); // Opset 25. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Unsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, ConstantOfShape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, If); @@ -2063,8 +2097,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2560,10 +2597,10 @@ 
static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2571,9 +2608,16 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 19-20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2641,6 +2685,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 21 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // TODO(fajin): support other quantized types BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2706,6 +2754,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2728,8 +2780,16 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2782,10 +2842,18 @@ static Status 
RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 25 BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc new file mode 100644 index 0000000000000..7a0b896acfe01 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv (deformable convolution 2D). + +#include "core/providers/shared_library/provider_api.h" +#include "deform_conv.h" +#include "deform_conv_impl.h" + +#include + +#include "core/common/narrow.h" +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kMaxParallelImgs = 32; + +// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes. +// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately. +// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration +// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision. 
// Contract: for n > 0 and bound > 0 the result divides n and lies in [1, bound];
// a non-positive n or bound yields 1.
int GetGreatestDivisorBelowBound(int n, int bound) {
  if (bound <= 0 || n <= 0) return 1;
  if (bound >= n) return n;          // n is its own greatest divisor; nothing to search
  if (n % bound == 0) return bound;  // Fast path: batch is a multiple of the target

  // Widen to 64 bits so that bound * bound cannot overflow int.
  const int64_t n64 = n;
  const int64_t bound64 = bound;

  if (n64 >= bound64 * bound64) {
    // n >= bound^2 <=> bound <= sqrt(n) => a linear scan down from bound - 1 is cheaper.
    for (int k = bound - 1; k > 1; --k) {
      if (n % k == 0) return k;
    }
    return 1;
  }

  // n < bound^2 <=> bound > sqrt(n) => enumerating divisor pairs (i, n / i) is cheaper.
  int best = 1;
  for (int i = 1; static_cast<int64_t>(i) * i <= n64; ++i) {
    if (n % i != 0) continue;
    const int q = n / i;  // the divisor paired with i
    if (q <= bound && q > best) best = q;
    if (i <= bound && i > best) best = i;
  }
  return best;
}

// Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers.
// Uses a fraction of total GPU memory to avoid OOM while leaving room for weights, activations,
// and other ops. No CUDA API is called; total_global_mem is expected from cached device props.
//
// Formula:
//   budget = total_global_mem * kFraction
//   return clamp(budget, kMin, kMax)
// with kFraction = 0.1 (10%), kMin = 32 MiB, kMax = 2 GiB.
//
// Example results (effective_max_temp after clamp):
//   GPU              | totalGlobalMem | effective_max_temp
//   -----------------|----------------|--------------------
//   A100 80GB        | 80 GiB         | 2 GiB (capped)
//   RTX 5080 16GB    | 16 GiB         | 1.6 GiB
//   RTX 4090 24GB    | 24 GiB         | 2 GiB (capped)
//   RTX 3080 10GB    | 10 GiB         | 1 GiB
//   GTX 1060 6GB     | 6 GiB          | 614.4 MiB
//   GTX 1050 4GB     | 4 GiB          | 409.6 MiB
//   Jetson 2GB       | 2 GiB          | 204.8 MiB
size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) {
  constexpr double kFraction = 0.1;
  constexpr size_t kMin = 32ULL * 1024 * 1024;
  constexpr size_t kMax = 2ULL * 1024 * 1024 * 1024;
  // The product is truncated toward zero when converted back to size_t.
  const size_t budget = static_cast<size_t>(static_cast<double>(total_global_mem) * kFraction);
  return std::clamp(budget, kMin, kMax);
}

// Returns how many images to process in parallel per batch chunk for DeformConv.
+// Chooses the largest divisor of batch size N that fits in the temp budget and does not +// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder). +// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1, +// so batching is effectively disabled (single-image chunks). +// +// Formulas: +// kernel_size = kH * kW +// output_image_size = out_h * out_w +// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T) +// (temp bytes per image: im2col col buffer + GEMM output buffer per output position) +// max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image)) +// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem) +// return GetGreatestDivisorBelowBound(N, target_parallel_imgs) +template +int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) { + const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem); + const int64_t kernel_size = params.kH * params.kW; + const int64_t output_image_size = params.out_h * params.out_w; + const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); + const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); + const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); + return GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); +} + +} // namespace + +/* Computes DeformConv on CUDA: validates and parses the input shapes, then for each batch chunk runs the deformable im2col step followed by a cuBLAS GEMM against the weights, and finally adds the bias when one is provided. Returns the first failing Status, or OK. */ +template +Status DeformConv::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); + const auto* mask = context->Input(4); + + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse( + attrs_, + X->Shape(), + W->Shape(), +
offset->Shape(), + B ? &B->Shape() : nullptr, + mask ? &mask->Shape() : nullptr, + params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; + + /* Allocate the output; an empty output tensor means there is nothing left to do. */ + Tensor* Y = context->Output(0, {N, M, out_h, out_w}); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const int n_parallel_imgs = GetNParallelImgs(params, GetDeviceProp().totalGlobalMem); + + const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = (C / group) * kernel_size; + + const int64_t col_stride = static_cast(n_parallel_imgs) * output_image_size; + const int64_t col_buffer_size = (C * kernel_size) * col_stride; + + /* Scratch buffers from the temp-space allocator: col_buffer has C * kernel_size * col_stride elements, gemm_output_buffer has (M / group) * col_stride elements. */ + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + // Removed col_transposed allocation as we avoid physical transpose. + auto gemm_output_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt((M / group) * col_stride)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ?
B->Data() : nullptr; + + cudaStream_t stream = Stream(context); + cublasHandle_t cublas = GetCublasHandle(context); + const cudaDeviceProp& device_prop = GetDeviceProp(); + CudaT alpha = ToCudaType::FromFloat(1.0f); + CudaT beta = ToCudaType::FromFloat(0.0f); + + /* Process the batch in chunks of n_parallel_imgs images; b is the index of the first image of the chunk. */ + for (int64_t b = 0; b < N; b += n_parallel_imgs) { + const int cur_parallel = static_cast(std::min(static_cast(n_parallel_imgs), N - b)); + const int64_t cur_out_size = static_cast(cur_parallel) * output_image_size; + + /* Advance each input pointer to the first image of this chunk. */ + const T* X_block = Xdata + b * (C * input_image_size); + const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr; + + /* Deformable im2col for this chunk; col_buffer is presumably written here, since it is the GEMM A operand below. */ + ORT_RETURN_IF_ERROR(DeformConvIm2ColImpl( + stream, + X_block, + offset_block, + mask_block, + col_buffer.get(), + cur_parallel, + C, + H, + W_in, + kH, + kW, + out_h, + out_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + offset_group, + use_mask)); + + // GEMM layout trick: compute Y = W * Col without physical transpose. + // + // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size]. + // cuBLAS is column-major. Key insight: row-major A[M,K] in memory equals column-major A^T[K,M]. + // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose): + // - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T + // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T + // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major + // + // m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size. + // + // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write + // directly into Y_g.
Use strided batched for all groups in one call. + // cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. + + /* NOTE(review): col_stride is derived from n_parallel_imgs while the GEMM leading dimensions use cur_parallel; they agree because GetNParallelImgs returns a divisor of N, so every chunk is full. */ + const bool gemm_writes_directly = (cur_parallel == 1); + if (gemm_writes_directly) { + // Strided batched: one call for all groups. Strides between batches: + const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 + const int64_t stride_weight = (M / group) * kernel_dim; + const int64_t stride_y = (M / group) * output_image_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(output_image_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_buffer.get()), + narrow(output_image_size), + stride_col, + reinterpret_cast(Wdata), + narrow(kernel_dim), + stride_weight, + &beta, + reinterpret_cast(Ydata + b * M * output_image_size), + narrow(output_image_size), + stride_y, + narrow(group), + device_prop, + UseTF32())); + } else { + // cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group.
+ for (int64_t g = 0; g < group; ++g) { + const T* W_g = Wdata + g * (M / group) * kernel_dim; + const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; + T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; + + CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(cur_out_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_g), + narrow(cur_out_size), + reinterpret_cast(W_g), + narrow(kernel_dim), + &beta, + reinterpret_cast(gemm_output_buffer.get()), + narrow(cur_out_size), + device_prop, + UseTF32()))); + + ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW( + stream, + gemm_output_buffer.get(), + Y_g, + M, + M / group, + output_image_size, + cur_parallel)); + } + } + } + + /* Bias (input 3) is optional; it is added to Y only when provided. */ + if (Bdata != nullptr) { + ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w)); + } + + return Status::OK(); +} + +/* Registers the CUDA DeformConv kernel for element type T twice: a versioned registration for opsets 19-21 and a registration starting at opset 22. */ +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 19, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float); +REGISTER_DEFORMCONV_KERNEL_TYPED(double); +REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16); + +// BFloat16 only for opset 22; opset 19-21 do not support BFloat16.
+ONNX_OPERATOR_TYPED_KERNEL_EX( + DeformConv, + kOnnxDomain, + 22, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + DeformConv); + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h new file mode 100644 index 0000000000000..fa564641d4b98 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/deform_conv_attributes.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +/* CUDA kernel class for the DeformConv operator. The constructor builds attrs_ from the OpKernelInfo; ComputeInternal performs the per-call computation. */ +template +class DeformConv final : public CudaKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : CudaKernel(info), attrs_(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + /* Operator attributes captured once at construction time. */ + DeformConvAttributes attrs_; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv); +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu new file mode 100644 index 0000000000000..7b3666fca810b --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -0,0 +1,512 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv: deformable im2col kernel + bilinear interpolation. +// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv spec.
+ +#include "deform_conv_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/fast_divmod.h" +#include "core/common/float16.h" +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kDeformConvThreadsPerBlock = 256; + +/* Compile-time wrapper that exposes an integer constant as ::value. */ +template +struct DeformConvKSize { + static constexpr int value = N; +}; + +// Calculate grid size with a safety limit to prevent overflow. +// Since we use grid-stride loops in kernels, limiting the grid size is safe. +inline int GetGridSize(size_t n, size_t threads_per_block) { + size_t blocks_needed = (n + threads_per_block - 1) / threads_per_block; + return static_cast(std::min(blocks_needed, static_cast(std::numeric_limits::max()))); +} + +// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly. +template +__device__ __inline__ T DeformConvLdg(const T* p) { + return __ldg(p); +} +template <> +__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) { + return BFloat16::FromBits(__ldg(reinterpret_cast(p))); +} + +// Traits for bilinear interpolation math: +// - ComputeT: type used for coordinate/weight math (float for half/BFloat16, T otherwise) +// - Load: load one element and convert to ComputeT +// - ToResult: convert ComputeT result back to T +// - Zero: zero value of T +template +struct DeformConvBilinearTraits { + using ComputeT = T; + + __device__ static __inline__ ComputeT Load(const T* p) { + return __ldg(p); + } + + __device__ static __inline__ T ToResult(ComputeT v) { + return v; + } + + __device__ static __inline__ T Zero() { + return static_cast(0); + } +}; + +/* Specialization whose Load takes const half*: values are read with __ldg and converted to float for the math. */ +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const half* p) { + return __half2float(__ldg(p)); + } + + __device__ static __inline__ half ToResult(ComputeT v) { + return __float2half(v); + } + + __device__ static __inline__ half Zero() { + return
__float2half(0.0f); + } +}; + +/* Specialization whose Load takes const BFloat16*: values are read through DeformConvLdg and the math runs in float. */ +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const BFloat16* p) { + return static_cast(DeformConvLdg(p)); + } + + __device__ static __inline__ BFloat16 ToResult(ComputeT v) { + return BFloat16(v); + } + + __device__ static __inline__ BFloat16 Zero() { + return BFloat16(0.0f); + } +}; + +// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). +// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure and +// improve occupancy in the hot path. Limitation: (H+1)*W must not exceed INT_MAX; this is +// validated on the host side in DeformConvValidateAndParse to guarantee index math in int +// does not overflow. For half/BFloat16, coordinate and weight math use float via +// DeformConvBilinearTraits to avoid precision loss. We keep floor() results in CoordT and +// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT +// round trips when computing lh/lw/hh/hw. +template +__device__ __inline__ T BilinearInterpolate( + const T* in, + int height, + int width, + typename DeformConvBilinearTraits::ComputeT h, + typename DeformConvBilinearTraits::ComputeT w) { + using Traits = DeformConvBilinearTraits; + using CoordT = typename Traits::ComputeT; + + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return Traits::Zero(); + } + + // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw.
+ CoordT h_floor = _Floor(h); + CoordT w_floor = _Floor(w); + int h_low = static_cast(h_floor); + int w_low = static_cast(w_floor); + int h_high = h_low + 1; + int w_high = w_low + 1; + + /* lh/lw are the fractional parts of h/w; hh/hw are their complements (1 - fraction). */ + CoordT lh = h - h_floor; + CoordT lw = w - w_floor; + CoordT hh = static_cast(1) - lh; + CoordT hw = static_cast(1) - lw; + + // [Optimization 3]: Avoid a second multiply for base_high. + // Original code computed both bases as: + // base_low = h_low * width; + // base_high = h_high * width; + // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and + // save one integer multiply in the hot path: + // base_low = h_low * width; + // base_high = base_low + width; + int base_low = h_low * width; + int base_high = base_low + width; + + /* Load the four neighbours (h_low or h_high, w_low or w_high); a neighbour outside the image contributes 0. */ + CoordT v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast(0); + CoordT v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast(0); + CoordT v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast(0); + CoordT v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast(0); + + /* Bilinear weights, paired with v1..v4 in the same order. */ + CoordT w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return Traits::ToResult(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); +} + +// kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling.
+/* Deformable im2col kernel. Each loop iteration decodes index into (in_c, out_b, out_y, out_x) and, for every kernel tap, stores one bilinearly sampled input value multiplied by mask_val (1 unless use_mask is set) into data_col. */ +template +__global__ void DeformableIm2ColKernel( + IndexT num_kernels, + const T* __restrict__ input, + const T* __restrict__ offset, + const T* __restrict__ mask, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t channels, + int64_t offset_group, + DivMod out_h_div, + DivMod out_w_div, + DivMod parallel_imgs_div, + DivMod channel_per_offset_grp_div, + bool use_mask, + T* __restrict__ data_col) { + constexpr bool is_fixed = (kH >= 0 && kW >= 0); + const int64_t h_dim = is_fixed ? kH : weight_h; + const int64_t w_dim = is_fixed ? kW : weight_w; + + // Reconstruct dimensions from DivMod objects + const int64_t out_h = out_h_div.d_; + const int64_t out_w = out_w_div.d_; + const int64_t parallel_imgs = parallel_imgs_div.d_; + + const int64_t out_size = out_h * out_w; + // The stride for data_col is (parallel_imgs * out_h * out_w) + const int64_t col_stride = parallel_imgs * out_size; + + using CoordT = typename DeformConvBilinearTraits::ComputeT; + + /* Each thread starts at its global index and advances by the total number of launched threads until num_kernels is covered. NOTE(review): blockIdx.x * blockDim.x is presumably evaluated in 32-bit unsigned arithmetic before conversion to IndexT - confirm launch sizes keep it in range. */ + for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { + IndexT val = index; + IndexT out_x, out_y, out_b, in_c; + + // Fast division/modulo to recover coordinates + out_w_div.divmod(val, val, out_x); + out_h_div.divmod(val, val, out_y); + parallel_imgs_div.divmod(val, in_c, out_b); + + // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case). + IndexT offset_grp = 0; + if (offset_group > 1) { + IndexT dummy; + channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + } + + // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic + // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops. + + // 1. Input pointer base for this batch and channel.
+ const T* input_ptr = input + static_cast(out_b) * (channels * height * width) + static_cast(in_c) * (height * width); + + // 2. Spatial index in the output feature map. + const int64_t spatial_idx = static_cast(out_y) * out_w + static_cast(out_x); + + // 3. Offset pointer base calculation. + // Layout: (N, offset_groups, 2*KH*KW, OH, OW) + // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx. + const int64_t offset_group_block_size = 2 * h_dim * w_dim * out_size; + const T* offset_ptr_base = offset + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * offset_group_block_size + spatial_idx; + + // 4. Mask pointer base calculation (if used). + // Layout: (N, offset_groups, KH*KW, OH, OW) + const T* mask_ptr_base = nullptr; + if (use_mask) { + const int64_t mask_group_block_size = h_dim * w_dim * out_size; + mask_ptr_base = mask + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * mask_group_block_size + spatial_idx; + } + + // 5. Output pointer base calculation. + // data_col Layout: (C * KH * KW, N * OH * OW) + // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx. + // The starting row for this channel is `in_c * KH * KW`. + const int64_t c_col = static_cast(out_b) * out_size + spatial_idx; + T* data_col_ptr_base = data_col + (static_cast(in_c) * h_dim * w_dim) * col_stride + c_col; + + // 6. Pre-calculate invariant coordinate parts. + // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. + const CoordT base_h_im = static_cast(out_y * stride_h - pad_h); + const CoordT base_w_im = static_cast(out_x * stride_w - pad_w); + + /* Handles one kernel tap (i, j): reads its mask value and offsets, samples the input bilinearly and stores the product into data_col. */ + auto process_kernel_point = [&](int64_t i, int64_t j) { + const int64_t kernel_idx = i * w_dim + j; + T mask_val = static_cast(1); + if (use_mask) { + // Access mask using pre-calculated base and stride.
+ mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); + } + + // Calculate offset pointers relative to the base. + // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. + // Stride between y_offset and x_offset is `out_size`. + const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; + + const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); + + const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; + const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; + + // height/width are validated on host (DeformConvValidateAndParse) so int is safe here. + T val = BilinearInterpolate(input_ptr, + static_cast(height), + static_cast(width), + h_im, + w_im); + + // Match CPU path: always interpolate then apply mask to keep branch-free hot loop. + data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; + }; + + /* With a compile-time kernel size the tap loops carry #pragma unroll; otherwise they iterate over the runtime weight_h x weight_w. */ + if constexpr (is_fixed) { +#pragma unroll + for (int i = 0; i < kH; ++i) { +#pragma unroll + for (int j = 0; j < kW; ++j) { + process_kernel_point(i, j); + } + } + } else { + for (int64_t i = 0; i < weight_h; ++i) { + for (int64_t j = 0; j < weight_w; ++j) { + process_kernel_point(i, j); + } + } + } + } +} + +// Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW. +template +__global__ void DeformConvAddBiasKernel( + T* Y, + const T* B, + DivMod spatial_div, // For dividing by (H * W) + DivMod channel_div, // For dividing by M (channel count) + int64_t total_elements) { + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += blockDim.x * gridDim.x) { + int64_t val = idx; + int64_t batch_channel_idx, pixel_idx; + + // 1.
First decomposition: decompose idx into (batch_channel_idx, pixel_idx) + // Equivalent to: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); + spatial_div.divmod(val, batch_channel_idx, pixel_idx); + + int64_t batch_idx, channel_idx; + + // 2. Second decomposition: decompose batch_channel_idx into (batch_idx, channel_idx) + // Equivalent to: channel_idx = batch_channel_idx % M; + // We only need channel_idx (i.e. m) + channel_div.divmod(batch_channel_idx, batch_idx, channel_idx); + (void)batch_idx; // Only channel_idx is needed + + // channel_idx is what we need (i.e. m) + Y[idx] += DeformConvLdg(B + channel_idx); + } +} + +// Copy GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) into NCHW Y_g. +// src(c, j) with j = b_idx*output_image_size + pos -> dst[b_idx*M*output_image_size + c*output_image_size + pos]. +template +__global__ void CopyGemmOutputRowMajorToNCHWKernel( + const T* __restrict__ src, + T* __restrict__ dst, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { + int64_t pos = idx % output_image_size; + int64_t c = (idx / output_image_size) % M_per_group; + int64_t b_idx = idx / (output_image_size * M_per_group); + int64_t j = b_idx * output_image_size + pos; + // src index for row-major: c * (cur_parallel * output_image_size) + j + dst[b_idx * M * output_image_size + c * output_image_size + pos] = src[c * (cur_parallel * output_image_size) + j]; + } +} + +} // namespace + +template +Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { + int64_t total = N * M * out_h * out_w; + if (total <= 0) return Status::OK(); + + // 1. Prepare divisor + int64_t out_size = out_h * out_w; + + // 2. 
Create FastDivMod object (note: ensure int64_t version of DivMod is used here) + DivMod spatial_div(out_size); + DivMod channel_div(M); + + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + + // 3. Pass DivMod objects + DeformConvAddBiasKernel<<>>( + Y, + B, + spatial_div, + channel_div, + total); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + if (total <= 0) return Status::OK(); + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + CopyGemmOutputRowMajorToNCHWKernel<<>>( + gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, + const T* offset, + const T* mask, + T* col_buffer, + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask) { + const int64_t num_kernels = static_cast(C) * out_h * out_w * parallel_imgs; + if (num_kernels <= 0) { + return Status::OK(); + } + + const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; + const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) || + (col_numel > static_cast(std::numeric_limits::max())); + + int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); + + auto launch = [&](auto kH_tag, auto kW_tag) { + constexpr int KH = decltype(kH_tag)::value; + constexpr int KW = decltype(kW_tag)::value; + if (use_64bit) { + DeformableIm2ColKernel<<>>( + num_kernels, input, offset, 
mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(out_h), DivMod(out_w), DivMod(parallel_imgs), + DivMod(C / offset_group), use_mask, col_buffer); + } else { + DeformableIm2ColKernel<<>>( + static_cast(num_kernels), input, offset, mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(static_cast(out_h)), + DivMod(static_cast(out_w)), + DivMod(static_cast(parallel_imgs)), + DivMod(static_cast(C / offset_group)), + use_mask, col_buffer); + } + }; + + if (kH == 1 && kW == 1) { + launch(DeformConvKSize<1>{}, DeformConvKSize<1>{}); + } else if (kH == 3 && kW == 3) { + launch(DeformConvKSize<3>{}, DeformConvKSize<3>{}); + } else if (kH == 5 && kW == 5) { + launch(DeformConvKSize<5>{}, DeformConvKSize<5>{}); + } else { + launch(DeformConvKSize<-1>{}, DeformConvKSize<-1>{}); + } + return CUDA_CALL(cudaGetLastError()); +} + +#define INST_DeformConvIm2ColImpl(T) \ + template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool) + +INST_DeformConvIm2ColImpl(float); +INST_DeformConvIm2ColImpl(double); +INST_DeformConvIm2ColImpl(half); +INST_DeformConvIm2ColImpl(BFloat16); + +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); + +template Status DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); +template 
Status DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); + +// Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. +#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ + template <> \ + Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ + int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ + int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ + int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ + int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ + return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + reinterpret_cast(offset), \ + mask ? 
reinterpret_cast(mask) : nullptr, \ + reinterpret_cast(col_buffer), \ + parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ + offset_group, use_mask); \ + } \ + template <> \ + Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + const ORT_T* gemm_output, ORT_T* Y_g, \ + int64_t M, int64_t M_per_group, \ + int64_t output_image_size, int64_t cur_parallel) { \ + return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + reinterpret_cast(gemm_output), \ + reinterpret_cast(Y_g), \ + M, M_per_group, output_image_size, cur_parallel); \ + } \ + template <> \ + Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T * Y, const ORT_T* B, \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + reinterpret_cast(B), N, M, out_h, out_w); \ + } + +// BFloat16 is not delegated: ORT's BFloat16 is the same type used in device code (ToCudaType in +// cuda_common.h), so the explicit instantiations above (INST_DeformConvIm2ColImpl(BFloat16), etc.) suffice. +DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h new file mode 100644 index 0000000000000..0c26cb55311bc --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/status.h" + +namespace onnxruntime { +namespace cuda { + +// Adds bias to output: Y[n,m,oh,ow] += B[m]. Y is [N, M, out_h, out_w], B is [M]. +// T may be float, double, MLFloat16 (FP16), or BFloat16. 
+template +Status DeformConvAddBiasImpl( + cudaStream_t stream, + T* Y, + const T* B, + int64_t N, + int64_t M, + int64_t out_h, + int64_t out_w); + +// Copies GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) to NCHW slice at Y_g. +// T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel); + +// Fills col_buffer with deformable im2col. col_buffer layout: row-major [C*kH*kW, parallel_imgs*out_h*out_w]. +// Called once per batch block; caller does GEMM and bias. T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, // [parallel_imgs, C, H, W] + const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] + const T* mask, // [parallel_imgs, offset_group*kH*kW, out_h, out_w] or nullptr + T* col_buffer, // [C*kH*kW, parallel_imgs*out_h*out_w] + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc index 71fb066c2898f..5d876ae5a2cc9 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc @@ -7,11 +7,37 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ RoiAlign, \ kOnnxDomain, \ 10, \ + 15, \ + T, \ + 
kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 16, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_TYPED_ROIALIGN_OP_22(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -67,13 +93,22 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ +#define SPECIALIZED_COMPUTE(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ADD_TYPED_ROIALIGN_OP_22(T) \ template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double) -// SPECIALIZED_COMPUTE(MLFloat16) +// MLFloat16 is available for RoiAlign op from version 16 (not version 10): +ADD_VERSIONED_TYPED_ROIALIGN_OP_16(MLFloat16) +ADD_TYPED_ROIALIGN_OP_22(MLFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; + +// BFloat16 is available for RoiAlign op from version 22: +ADD_TYPED_ROIALIGN_OP_22(BFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; } // namespace cuda }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu index 7acfd9d075461..87f4aba8e45b2 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu @@ -17,64 +17,72 @@ 
#include "roialign_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/accumulation_type.h" namespace onnxruntime { namespace cuda { template -__device__ T bilinear_interpolate( +__device__ AccumulationType_t bilinear_interpolate( const T* bottom_data, const int height, const int width, - T y, - T x, + AccumulationType_t y, + AccumulationType_t x, const bool is_mode_avg, const int index /* index for debug only*/) { + using TAcc = AccumulationType_t; + // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { + if (y < static_cast(-1.0f) || y > static_cast(height) || + x < static_cast(-1.0f) || x > static_cast(width)) { // empty - return 0; + return static_cast(0.0f); } - if (y <= 0) { - y = 0; + if (y <= static_cast(0.0f)) { + y = static_cast(0.0f); } - if (x <= 0) { - x = 0; + if (x <= static_cast(0.0f)) { + x = static_cast(0.0f); } - int y_low = (int)y; - int x_low = (int)x; + int y_low = static_cast(y); + int x_low = static_cast(x); int y_high; int x_high; if (y_low >= height - 1) { y_high = y_low = height - 1; - y = (T)y_low; + y = static_cast(y_low); } else { y_high = y_low + 1; } if (x_low >= width - 1) { x_high = x_low = width - 1; - x = (T)x_low; + x = static_cast(x_low); } else { x_high = x_low + 1; } - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; + TAcc ly = y - static_cast(y_low); + TAcc lx = x - static_cast(x_low); + TAcc hy = static_cast(1.0f) - ly; + TAcc hx = static_cast(1.0f) - lx; // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + TAcc v1 = static_cast(bottom_data[y_low * width + x_low]); + TAcc v2 = static_cast(bottom_data[y_low * width + x_high]); + TAcc v3 = static_cast(bottom_data[y_high * width + x_low]); + TAcc v4 = static_cast(bottom_data[y_high * width + x_high]); + TAcc w1 = hy * hx; + TAcc w2 = hy * lx; + TAcc w3 = ly * hx; + TAcc w4 = ly * lx; - T val = is_mode_avg - ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg - : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max + TAcc val = is_mode_avg + ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg + : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max return val; } @@ -97,6 +105,8 @@ __global__ void RoIAlignForward( const bool half_pixel, const int64_t* batch_indices_ptr, const int64_t batch_size) { + using TAcc = AccumulationType_t; + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -111,26 +121,27 @@ __global__ void RoIAlignForward( // If the index is out of range, we set the output to 0 for this RoI element. if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range"); - top_data[index] = 0; + top_data[index] = static_cast(0.0f); continue; } // Do not using rounding; this implementation detail is critical - T roi_offset = half_pixel ? 
T(0.5) : T(0); - T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset; - T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset; - T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset; - T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; + const TAcc spatial_scale_acc = static_cast(spatial_scale); + const TAcc roi_offset = half_pixel ? static_cast(0.5f) : static_cast(0.0f); + TAcc roi_start_w = static_cast(offset_bottom_rois[0]) * spatial_scale_acc - roi_offset; + TAcc roi_start_h = static_cast(offset_bottom_rois[1]) * spatial_scale_acc - roi_offset; + TAcc roi_end_w = static_cast(offset_bottom_rois[2]) * spatial_scale_acc - roi_offset; + TAcc roi_end_h = static_cast(offset_bottom_rois[3]) * spatial_scale_acc - roi_offset; + + TAcc roi_width = roi_end_w - roi_start_w; + TAcc roi_height = roi_end_h - roi_start_h; if (!half_pixel) { // backward compatibility // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); + roi_width = max(roi_width, static_cast(1.0f)); + roi_height = max(roi_height, static_cast(1.0f)); } - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + const TAcc bin_size_h = roi_height / static_cast(pooled_height); + const TAcc bin_size_w = roi_width / static_cast(pooled_width); const T* offset_bottom_data = bottom_data + static_cast((roi_batch_ind * channels + c) * height * width); @@ -138,26 +149,27 @@ __global__ void RoIAlignForward( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : _Ceil(roi_height / pooled_height); // e.g., = 2 + : static_cast(_Ceil(roi_height / static_cast(pooled_height))); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : _Ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? sampling_ratio : static_cast(_Ceil(roi_width / static_cast(pooled_width))); // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + const int grid_count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + const TAcc count = static_cast(grid_count); // e.g. = 4 - T output_val = 0.; + TAcc output_val = static_cast(0.0f); bool max_flag = false; for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + const TAcc y = roi_start_h + static_cast(ph) * bin_size_h + + (static_cast(iy) + static_cast(0.5f)) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); + const TAcc x = roi_start_w + static_cast(pw) * bin_size_w + + (static_cast(ix) + static_cast(0.5f)) * bin_size_w / + static_cast(roi_bin_grid_w); - T val = bilinear_interpolate( + const TAcc val = bilinear_interpolate( offset_bottom_data, height, width, y, x, is_mode_avg, index); if (is_mode_avg) { @@ -176,7 +188,7 @@ __global__ void RoIAlignForward( output_val /= count; } - top_data[index] = output_val; + top_data[index] = static_cast(output_val); } } @@ -241,6 +253,8 @@ void RoiAlignImpl( SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 9b23209953081..3dd50c1c03cbf 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -40,10 +40,70 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", 
DataTypeImpl::GetTensorType()), \ Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 18, 18, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 19, 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 21, 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 23, 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 24, 24, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Pad, \ kOnnxDomain, \ - 18, \ + 25, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -154,6 +214,11 @@ Status Pad::ComputeInternal(OpKernelContext* 
ctx) const { effective_input_extents.push_back(extent); } + TArray input_offsets(dimension_count); + for (int32_t i = 0; i < dimension_count; ++i) { + input_offsets[i] = -(*p_slices)[i]; + } + TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); @@ -236,7 +301,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } - if (IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) { + if (mode_ != Mode::Wrap && + IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) { // If we have entered here, it means the input can only be 4-D (NCHW), 3-D (CHW), or 2-D (HW) // NCHW input @@ -282,6 +348,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { input_dims, input_strides, lower_pads, + TArray(effective_input_extents), + input_offsets, value, static_cast(mode_), reinterpret_cast::MappedType*>(input_tensor.Data()), diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu index 6f530e800fdf2..6020769bf0ddf 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu @@ -7,19 +7,27 @@ namespace onnxruntime { namespace cuda { -// PadMode enum from core/providers/cpu/tensor/pad.h, cannot use that header because of nvcc/onnxruntime incompatibility +// PadMode enum from core/providers/cpu/tensor/padbase.h, cannot use that header because of nvcc/onnxruntime incompatibility enum class PadMode : int { Constant = 0, Reflect, - Edge + Edge, + Wrap }; +__device__ __forceinline__ int64_t WrapCoordinate(int64_t coord, int64_t extent) { + int64_t wrapped = coord % extent; + return wrapped < 0 ? 
wrapped + extent : wrapped; +} + template __global__ void _PadKernel( const size_t shape_rank, const TArray input_dims, const TArray input_strides, const TArray lower_pads, + const TArray effective_input_extents, + const TArray input_offsets, const T pad_value, const T* input_data, const TArray fdm_output_strides, @@ -33,33 +41,44 @@ __global__ void _PadKernel( int out_coord, r; fdm_output_strides[dim].divmod(output_index, out_coord, r); output_index = r; - int in_coord = 0; - if (out_coord < lower_pads[dim]) { - switch ((PadMode)pad_mode) { - case PadMode::Constant: - use_pad_value = true; - break; - case PadMode::Edge: - in_coord = 0; - break; - case PadMode::Reflect: - in_coord = lower_pads[dim] - out_coord; - break; - } - } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { - switch ((PadMode)pad_mode) { - case PadMode::Constant: - use_pad_value = true; - break; - case PadMode::Edge: - in_coord = input_dims[dim] - 1; - break; - case PadMode::Reflect: - in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); - break; - } + int64_t in_coord = 0; + if constexpr (pad_mode == static_cast(PadMode::Wrap)) { + const int64_t effective_input_extent = effective_input_extents[dim]; + const int64_t pre_pad = lower_pads[dim] + input_offsets[dim]; + const int64_t relative_coord = static_cast(out_coord) - pre_pad; + in_coord = input_offsets[dim] + WrapCoordinate(relative_coord, effective_input_extent); } else { - in_coord = out_coord - lower_pads[dim]; + if (out_coord < lower_pads[dim]) { + switch ((PadMode)pad_mode) { + case PadMode::Constant: + use_pad_value = true; + break; + case PadMode::Edge: + in_coord = 0; + break; + case PadMode::Reflect: + in_coord = lower_pads[dim] - out_coord; + break; + case PadMode::Wrap: + break; + } + } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { + switch ((PadMode)pad_mode) { + case PadMode::Constant: + use_pad_value = true; + break; + case PadMode::Edge: + in_coord = input_dims[dim] - 1; + 
break; + case PadMode::Reflect: + in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); + break; + case PadMode::Wrap: + break; + } + } else { + in_coord = out_coord - lower_pads[dim]; + } } input_index += input_strides[dim] * in_coord; } @@ -136,6 +155,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& effective_input_extents, + const TArray& input_offsets, const T pad_value, const int pad_mode, const T* input_data, @@ -149,17 +170,22 @@ void PadImpl( switch (pad_mode) { case 0: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; case 1: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; case 2: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, + pad_value, input_data, fdm_output_strides, output_data, N); + break; + case 3: + _PadKernel<<>>( + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; } @@ -211,6 +237,8 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( template void PadImpl(cudaStream_t stream, const size_t shape_rank, \ const TArray& input_dims, const TArray& input_strides, \ const TArray& lower_pads, \ + const TArray& effective_input_extents, \ + const TArray& input_offsets, \ const T pad_value, \ const int pad_mode, \ const T* input_data, \ diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.h b/onnxruntime/core/providers/cuda/tensor/pad_impl.h index dc700ea2304e9..96f158dd187fc 100644 --- 
a/onnxruntime/core/providers/cuda/tensor/pad_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.h @@ -32,6 +32,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& effective_input_extents, + const TArray& input_offsets, const T pad_value, const int pad_mode, const T* input_data, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index cce90f3ef82be..e20cc9140916a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -268,7 +268,7 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { const fs::path path{main_graph.ModelPath()}; if (path.has_filename()) { - const auto model_name = path.filename().string(); + const auto model_name = PathToUTF8String(path.filename().native()); LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; // Ensure enough characters are hashed in case model names are too short diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 7baac6aa1f6d0..0b137c4674b00 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -166,7 +166,7 @@ Status CreateCtxNode(const GraphViewer& graph_viewer, } attr_ep_cache_context->set_s(engine_data_str); } else { - std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string(); + std::string engine_cache_filename = PathToUTF8String(std::filesystem::path(engine_cache_path).filename().native()); attr_ep_cache_context->set_s(engine_cache_filename); std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out); if (engine_cache_file.is_open()) { @@ 
-188,7 +188,7 @@ Status CreateCtxNode(const GraphViewer& graph_viewer, attr_onnx_filename->set_name(ONNX_MODEL_FILENAME); attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string()); + attr_onnx_filename->set_s(PathToUTF8String(std::filesystem::path(onnx_model_path).filename().native())); attr_sdk_version->set_name(SDK_VERSION); attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 8c8e1879a2c6b..1d17a72641a4e 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -185,7 +185,7 @@ void BackendManager::TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVi model_blob_str = std::move(ss).str(); } } else { // External blob - model_blob_str = shared_context_.GetBinPath().filename().string(); + model_blob_str = PathToUTF8String(shared_context_.GetBinPath().filename().native()); } auto status = ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 0e49c0f897bea..68aa9a157f4a2 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -269,7 +269,7 @@ Status CreateEPContextNodes(Model* model, } context_bin_path = context_bin_path + ToPathString("_qnn.bin"); - context_cache_name = std::filesystem::path(context_bin_path).filename().string(); + context_cache_name = PathToUTF8String(std::filesystem::path(context_bin_path).filename().native()); // If generate ctx.onnx with share_ep_context enabled, all ctx.onnx should point to the same ctx.bin if (share_ep_contexts) { diff --git 
a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc index fb76f2110cbc8..87340e5b3ebeb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc @@ -237,7 +237,7 @@ Serializer::Serializer(const ProfilingInfo& profiling_info, tracelogging_provider_ep_enabled_(tracelogging_provider_ep_enabled) { #ifdef QNN_SYSTEM_PROFILE_API_ENABLED std::filesystem::path output_fs_filepath(profiling_info.csv_output_filepath); - qnn_log_filename_ = output_fs_filepath.filename().string(); + qnn_log_filename_ = PathToUTF8String(output_fs_filepath.filename().native()); // Remove extension (assumed to be ".csv") then add "_qnn.log" size_t extension_start_idx = qnn_log_filename_.rfind("."); qnn_log_filename_ = qnn_log_filename_.substr(0, extension_start_idx); diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 2a54bfea86e91..091b110d8c746 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -120,7 +120,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, attr_2->set_s(compute_capability); attr_3->set_name(ONNX_MODEL_FILENAME); attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); + attr_3->set_s(PathToUTF8String(std::filesystem::path(onnx_model_path).filename().native())); attr_4->set_name(SOURCE); attr_4->set_type(onnx::AttributeProto_AttributeType_STRING); attr_4->set_s(kTensorrtExecutionProvider); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 7cd02e5413407..f37c685cf2f28 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ 
-618,7 +618,7 @@ Status Environment::CreateAndRegisterInternalEps() { Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) { std::lock_guard lock{mutex_}; - std::string lib_file_name = std::filesystem::path(lib_path).filename().string(); + std::string lib_file_name = PathToUTF8String(std::filesystem::path(lib_path).filename().native()); Env::Default().GetTelemetryProvider().LogRegisterEpLibraryWithLibPath(registration_name, lib_file_name); std::vector internal_factories = {}; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 08b58f3de1a11..b873c95b496bb 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2612,7 +2612,7 @@ common::Status InferenceSession::Initialize() { // and log telemetry std::filesystem::path model_path = graph.ModelPath(); - std::string model_file_name = model_path.filename().string(); + std::string model_file_name = PathToUTF8String(model_path.filename().native()); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); env.GetTelemetryProvider().LogSessionCreation( session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), @@ -4096,7 +4096,7 @@ void InferenceSession::LogAllSessions() { if (nullptr != model) { onnxruntime::Graph& graph = model->MainGraph(); std::filesystem::path model_path = graph.ModelPath(); - std::string model_file_name = model_path.filename().string(); + std::string model_file_name = PathToUTF8String(model_path.filename().native()); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); std::string model_weight_type = session->GetWeightDataType(); std::string model_graph_hash = session->GetGraphHash(); diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py new file mode 100644 index 
0000000000000..a6c4923cd961e
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py
@@ -0,0 +1,203 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""
+Generate expected outputs for DeformConv tests using torchvision.ops.deform_conv2d.
+Run with: .venv/bin/python onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py
+Outputs C++-friendly std::vector initializer lists for pasting into deform_conv_op_test.cc
+
+Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads
+[pad_h, pad_w, pad_h, pad_w] are derived from a single (pad_h, pad_w) pair. Asymmetric
+pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated.
+"""
+
+import torch
+import torchvision.ops
+
+
+def _pair(x: int | tuple[int, int]) -> tuple[int, int]:
+    """Return x as an (h, w) pair, duplicating it when it is a single int."""
+    if isinstance(x, int):
+        return (x, x)
+    return x
+
+
+def to_cpp_list(t: torch.Tensor, fmt="{:.6f}") -> str:
+    """Flatten tensor in NCHW order and format as C++ initializer list."""
+    t = t.detach().float().contiguous()
+    return ", ".join(fmt.format(x) + "f" for x in t.flatten().tolist())
+
+
+def _make_inputs(
+    batch_sz: int,
+    n_in: int,
+    n_out: int,
+    n_weight_grps: int,
+    n_offset_grps: int,
+    kernel_h: int,
+    kernel_w: int,
+    stride: tuple[int, int] | int,
+    pad: tuple[int, int] | int,
+    dilation: tuple[int, int] | int,
+    in_h: int,
+    in_w: int,
+    seed: int = 42,
+):
+    """Seed the RNG and draw the random tensors for one configuration.
+
+    The tensors are always drawn in the order x, offset, mask, weight, bias, and the
+    mask is drawn even if the caller discards it. Two calls with the same arguments
+    therefore return identical tensors, so a no-mask run can reuse exactly the X, W,
+    offset and B that were printed for the masked run.
+
+    Returns:
+        (x, offset, mask, weight, bias, out_h, out_w)
+    """
+    torch.manual_seed(seed)
+    stride_h, stride_w = _pair(stride)
+    pad_h, pad_w = _pair(pad)
+    dil_h, dil_w = _pair(dilation)
+
+    out_h = (in_h + 2 * pad_h - (dil_h * (kernel_h - 1) + 1)) // stride_h + 1
+    out_w = (in_w + 2 * pad_w - (dil_w * (kernel_w - 1) + 1)) // stride_w + 1
+
+    x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32)
+    offset = torch.randn(batch_sz, n_offset_grps * 2 * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32)
+    mask = torch.randn(batch_sz, n_offset_grps * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32)
+    weight = torch.randn(n_out, n_in // n_weight_grps, kernel_h, kernel_w, dtype=torch.float32)
+    bias = torch.randn(n_out, dtype=torch.float32)
+    return x, offset, mask, weight, bias, out_h, out_w
+
+
+def run_case(
+    name: str,
+    batch_sz: int,
+    n_in: int,
+    n_out: int,
+    n_weight_grps: int,
+    n_offset_grps: int,
+    kernel_h: int,
+    kernel_w: int,
+    stride: tuple[int, int] | int,
+    pad: tuple[int, int] | int,
+    dilation: tuple[int, int] | int,
+    in_h: int,
+    in_w: int,
+    seed: int = 42,
+):
+    """Build inputs with seed, run deform_conv2d, print C++ snippets.
+
+    Returns the output tensor computed by torchvision.ops.deform_conv2d.
+    """
+    stride_h, stride_w = _pair(stride)
+    pad_h, pad_w = _pair(pad)
+    dil_h, dil_w = _pair(dilation)
+
+    x, offset, mask, weight, bias, out_h, out_w = _make_inputs(
+        batch_sz, n_in, n_out, n_weight_grps, n_offset_grps, kernel_h, kernel_w,
+        stride, pad, dilation, in_h, in_w, seed,
+    )
+
+    # Standard answer from torchvision
+    out = torchvision.ops.deform_conv2d(
+        x,
+        offset,
+        weight,
+        bias=bias,
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        dilation=(dil_h, dil_w),
+        mask=mask,
+    )
+
+    # ONNX pads = [top, left, bottom, right] (symmetric: single pad_h, pad_w expanded)
+    pads_onnx = [pad_h, pad_w, pad_h, pad_w]
+
+    print(f"// --- {name} (seed={seed}) ---")
+    print(f"// Shapes: X({batch_sz},{n_in},{in_h},{in_w}) W({n_out},{n_in // n_weight_grps},{kernel_h},{kernel_w})")
+    print(f"// stride=({stride_h},{stride_w}) pad=({pad_h},{pad_w}) dilation=({dil_h},{dil_w})")
+    print(f"// out_h={out_h} out_w={out_w}")
+    print()
+    print("std::vector X = {" + to_cpp_list(x) + "};")
+    print("std::vector W = {" + to_cpp_list(weight) + "};")
+    print("std::vector offset = {" + to_cpp_list(offset) + "};")
+    print("std::vector B = {" + to_cpp_list(bias) + "};")
+    print("std::vector mask = {" + to_cpp_list(mask) + "};")
+    print("std::vector expected_Y = {" + to_cpp_list(out) + "};")
+    print()
+    print(
+        "// Params: kernel_shape={" + f"{kernel_h}, {kernel_w}" + "}, stride={" + f"{stride_h}, {stride_w}" + "}, pads={"
+        + ", ".join(map(str, pads_onnx))
+        + "}, dilations={"
+        + f"{dil_h}, {dil_w}"
+        + "}, group="
+        + str(n_weight_grps)
+        + ", offset_group="
+        + str(n_offset_grps)
+    )
+    print()
+    return out
+
+
+def main():
+    """Print the C++ snippets for every reference case to stdout."""
+    print("// Generated by deform_conv_expected_gen.py (torchvision.ops.deform_conv2d)")
+    print()
+
+    # Case 1: Same config as PyTorch TestDeformConv.get_fn_args (small batch for readability)
+    base_config = {
+        "batch_sz": 1,
+        "n_in": 6,
+        "n_out": 2,
+        "n_weight_grps": 2,
+        "n_offset_grps": 3,
+        "kernel_h": 3,
+        "kernel_w": 2,
+        "stride": (2, 1),
+        "pad": (1, 0),
+        "dilation": (2, 1),
+        "in_h": 5,
+        "in_w": 4,
+        "seed": 42,
+    }
+    run_case("PyTorch get_fn_args style (batch=1)", **base_config)
+
+    # Case 2: No mask (mask optional) - same config and the same X, W, offset, B as Case 1.
+    # _make_inputs still draws the mask (discarded here) so the RNG stream stays aligned
+    # with Case 1; skipping that draw would silently change weight and bias.
+    x, offset, _, weight, bias, _, _ = _make_inputs(**base_config)
+    out_no_mask = torchvision.ops.deform_conv2d(
+        x,
+        offset,
+        weight,
+        bias=bias,
+        stride=_pair(base_config["stride"]),
+        padding=_pair(base_config["pad"]),
+        dilation=_pair(base_config["dilation"]),
+        mask=None,
+    )
+    print("// --- Same inputs, no mask (expected_Y when mask is omitted) ---")
+    print("std::vector expected_Y_no_mask = {" + to_cpp_list(out_no_mask) + "};")
+    print()
+
+    # Case 3: groups=2, offset_group=2, non-zero offset (for GroupsWithNonZeroOffset test)
+    run_case(
+        "Groups with non-zero offset (batch=1, 2 groups)",
+        batch_sz=1,
+        n_in=4,
+        n_out=2,
+        n_weight_grps=2,
+        n_offset_grps=2,
+        kernel_h=2,
+        kernel_w=2,
+        stride=(1, 1),
+        pad=(0, 0),
+        dilation=(1, 1),
+        in_h=3,
+        in_w=3,
+        seed=123,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc
new file mode 100644
index 0000000000000..860c0d2f08b18
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc
@@ -0,0 +1,948 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Unit tests for DeformConv (CPU and Cuda), aligned with PyTorch Vision deform_conv2d tests.
+// Reference: https://github.com/pytorch/vision/blob/main/test/test_ops.py (TestDeformConv)
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/testdata/deform_conv_test_data.inc"
+#include "test/unittest_util/conversion.h"
+
+#if defined(USE_CUDA)
+#include "test/common/cuda_op_test_utils.h"
+#endif
+
+namespace onnxruntime {
+namespace test {
+
+namespace {
+
+// Parameters similar to PyTorch TestDeformConv::get_fn_args (smaller for speed).
+struct DeformConvTestParams {
+  int64_t batch_sz;
+  int64_t n_in_channels;
+  int64_t n_out_channels;
+  int64_t n_weight_grps;  // passed to the op as the "group" attribute
+  int64_t n_offset_grps;  // passed to the op as the "offset_group" attribute
+  std::vector kernel_shape;  // {kH, kW}
+  std::vector stride;    // indexed as {stride_h, stride_w} when the output size is computed
+  std::vector pad;       // ONNX order: {pad_top, pad_left, pad_bottom, pad_right}
+  std::vector dilation;  // indexed as {dilation_h, dilation_w} when the output size is computed
+  int64_t in_h;
+  int64_t in_w;
+};
+
+// Traits for type-specific DeformConv test behavior: input conversion, providers to skip, default tolerances. NOTE(review): template argument lists look stripped in this text (every specialization below reads `DeformConvTestTraits`) - confirm against the original file.
+template
+struct DeformConvTestTraits;
+
+template <>
+struct DeformConvTestTraits {
+  static std::vector Convert(const std::vector& v) { return v; }  // identity: data is used as given
+  static std::unordered_set ExcludedProviders() {
+    return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider,
+            kQnnExecutionProvider};
+  }
+  static constexpr float DefaultRtol() { return 1e-5f; }
+  static constexpr float DefaultAtol() { return 1e-5f; }
+};
+
+template <>
+struct DeformConvTestTraits {
+  static std::vector Convert(const std::vector& v) { return FloatsToMLFloat16s(v); }
+  static std::unordered_set ExcludedProviders() {
+    return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider,  // the CPU provider is skipped too for this type
+            kOpenVINOExecutionProvider, kQnnExecutionProvider};
+  }
+  static constexpr float DefaultRtol() { return 1e-2f; }  // looser than the 1e-5 used by the first specialization
+  static constexpr float DefaultAtol() { return 1e-2f; }
+};
+
+template <>
+struct DeformConvTestTraits {
+  static std::vector Convert(const std::vector& v) {
+    return std::vector(v.begin(), v.end());  // element-wise copy into a new vector
+  }
+  static std::unordered_set ExcludedProviders() {
+    return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider,
+            kQnnExecutionProvider};
+  }
+  static constexpr double DefaultRtol() { return 1e-8; }  // narrowed to float before AddOutput in RunDeformConvTest
+  static constexpr double DefaultAtol() { return 1e-8; }
+};
+
+#if defined(USE_CUDA)
+template <>
+struct DeformConvTestTraits {
+  static std::vector Convert(const std::vector& v) { return FloatsToBFloat16s(v); }
+  static std::unordered_set ExcludedProviders() {
+    return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider,  // the CPU provider is skipped too for this type
+            kOpenVINOExecutionProvider, kQnnExecutionProvider};
+  }
+  static constexpr float DefaultRtol() { return 1e-2f; }
+  static constexpr float DefaultAtol() { return 1e-2f; }
+};
+#endif
+
+template
+void RunDeformConvTest(const DeformConvTestParams& params,  // builds one DeformConv OpTester case and runs it expecting success
+                       const std::vector& X,
+                       const std::vector& W,
+                       const std::vector& offset,
+                       const std::vector& B,  // ignored when omit_bias is true
+                       const std::vector* mask,  // nullptr => the mask input is left unconnected
+                       const std::vector& expected_Y,
+                       int opset = 19,
+                       decltype(DeformConvTestTraits::DefaultRtol()) rtol = DeformConvTestTraits::DefaultRtol(),
+                       decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol(),
+                       bool omit_bias = false) {
+  const int64_t kH = params.kernel_shape[0];
+  const int64_t kW = params.kernel_shape[1];
+  // ONNX pads format: [pad_top, pad_left, pad_bottom, pad_right] = [pad[0], pad[1], pad[2], pad[3]]
+  const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] -
+                         params.dilation[0] * (kH - 1) - 1) /
+                            params.stride[0] +
+                        1;
+  const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] -
+                         params.dilation[1] * (kW - 1) - 1) /
+                            params.stride[1] +
+                        1;
+
+  OpTester test("DeformConv", opset);
+  test.AddAttribute("kernel_shape", params.kernel_shape);
+  test.AddAttribute("strides", params.stride);
+  test.AddAttribute("pads", params.pad);
+  test.AddAttribute("dilations", params.dilation);
+  test.AddAttribute("group", params.n_weight_grps);
+  test.AddAttribute("offset_group", params.n_offset_grps);
+
+  const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w};
+  const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW};
+  const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w};
+  const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w};
+
+  auto X_t = DeformConvTestTraits::Convert(X);  // inputs are converted through the traits before being added
+  auto W_t = DeformConvTestTraits::Convert(W);
+  auto offset_t = DeformConvTestTraits::Convert(offset);
+  auto expected_Y_t = DeformConvTestTraits::Convert(expected_Y);
+
+  test.AddInput("X", X_shape, X_t);
+  test.AddInput("W", W_shape, W_t);
+  test.AddInput("offset", offset_shape, offset_t);
+  if (omit_bias) {
+    test.AddOptionalInputEdge();  // keeps B's input slot but provides no tensor
+  } else {
+    auto B_t = DeformConvTestTraits::Convert(B);
+    test.AddInput("B", {params.n_out_channels}, B_t);
+  }
+  if (mask != nullptr) {
+    const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w};
+    test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask));
+  } else {
+    test.AddOptionalInputEdge();  // mask omitted
+  }
+
+  const float rtol_f = static_cast(rtol);  // tolerances are handed to AddOutput as float
+  const float atol_f = static_cast(atol);
+  test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol_f, atol_f);
+
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders());
+}
+
+// MinimalBilinear test: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear).
+// At (0,0) offset (0.5, 0.5) samples center of [1,2;3,4] -> 2.5.
+template
+void RunMinimalBilinearTest(int opset = 19, int min_cuda_arch = 0, bool omit_bias = false) {  // min_cuda_arch == 0 disables the CUDA capability gate
+#if defined(USE_CUDA)
+  if (min_cuda_arch > 0 && !HasCudaEnvironment(min_cuda_arch)) {
+    return;  // returns without running anything when the requested CUDA architecture is unavailable
+  }
+#else
+  (void)min_cuda_arch;  // parameter is unused in non-CUDA builds
+#endif
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {1, 1};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 2;
+  p.in_w = 2;
+  std::vector X = {1.f, 2.f, 3.f, 4.f};
+  std::vector W = {1.f};  // single unit weight: each output is just the sampled input value
+  // offset shape [N, 2*kH*kW, out_h, out_w] = [1, 2, 2, 2]: ch0=offset_h, ch1=offset_w (for kernel pt 0)
+  // Layout: offset[n,c,oh,ow]. Flattened (NCHW): [ch0@00, ch0@01, ch0@10, ch0@11, ch1@00, ch1@01, ch1@10, ch1@11]
+  // (0,0): (0.5, 0.5)->center of [1,2;3,4]->2.5; (0,1): (0,-1)->(0,0)->1; (1,0): (0,0)->3; (1,1): (0,0)->4
+  std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, -1.0f, 0.f, 0.f};
+  std::vector B = {0.f};
+  std::vector mask = {1.f, 1.f, 1.f, 1.f};
+  std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f};
+  if (omit_bias) {
+    RunDeformConvTest(p, X, W, offset, {} /* B unused */, &mask, expected_Y, opset,
+                      DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), true);
+  } else {
+    RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, opset,
+                      DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), false);
+  }
+}
+}  // namespace
+
+// Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear).
+TEST(DeformConvTest, MinimalBilinear) {
+  RunMinimalBilinearTest();  // defaults: opset 19, no CUDA gate, bias provided
+}
+
+// Optional bias omitted: same as MinimalBilinear but B is not provided; output must match B=0.
+TEST(DeformConvTest, OptionalBiasOmitted) {
+  RunMinimalBilinearTest(19, 0, true);  // (opset, min_cuda_arch, omit_bias)
+}
+
+// Minimal case FP16: Same as MinimalBilinear but in FP16 (CUDA-only).
+#if defined(USE_CUDA)
+TEST(DeformConvTest, MinimalBilinearFP16) {
+  RunMinimalBilinearTest(19, 530);  // (opset, min_cuda_arch)
+}
+
+// Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA-only, opset 22).
+TEST(DeformConvTest, MinimalBilinearBFloat16) {
+  RunMinimalBilinearTest(22, 800);  // (opset, min_cuda_arch)
+}
+#endif  // defined(USE_CUDA)
+
+// Minimal case Double (FP64): Same as MinimalBilinear in double precision.
+TEST(DeformConvTest, MinimalBilinearDouble) {
+  RunMinimalBilinearTest();
+}
+
+// Forward with mask and bias FP16 (CUDA-only; skip when CUDA not available).
+#if defined(USE_CUDA)
+TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) {
+  int min_cuda_architecture = 530;  // FP16 requires SM 5.3+
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "DeformConv FP16: CUDA not available, skipping.";
+    return;
+  }
+
+  DeformConvTestParams p = {};
+  p.batch_sz = 2;
+  p.n_in_channels = 4;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 2;
+  p.n_offset_grps = 2;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;  // 3x3 input, 2x2 kernel, stride 1, no padding
+  const int64_t out_w = 2;
+
+  const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w);
+  const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2);
+  const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w);  // batch * offset_group * 2 * kH * kW * out_h * out_w
+  const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w);  // batch * offset_group * kH * kW * out_h * out_w
+
+  std::vector X(x_size, 0.1f);
+  std::vector W(w_size, 0.1f);
+  std::vector offset(offset_size, 0.f);
+  std::vector mask(mask_size, 1.f);
+  std::vector B = {0.5f, -0.5f};
+
+  const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w);
+  std::vector expected_Y(y_size);
+  for (int64_t b = 0; b < p.batch_sz; ++b) {
+    for (int64_t c = 0; c < p.n_out_channels; ++c) {
+      float val = (c % 2 == 0) ? 0.58f : -0.42f;  // 0.08 from the uniform conv plus the per-channel bias (+0.5 / -0.5)
+      for (int64_t i = 0; i < out_h * out_w; ++i) {
+        expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val;
+      }
+    }
+  }
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);  // NOTE(review): the explicit template argument looks stripped from this text - confirm it selects the FP16 traits
+}
+#endif  // defined(USE_CUDA)
+
+// With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected.
+TEST(DeformConvTest, ForwardWithMaskAndBias) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 2;
+  p.n_in_channels = 4;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 2;
+  p.n_offset_grps = 2;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;  // 3x3 input, 2x2 kernel, stride 1, no padding
+  const int64_t out_w = 2;
+
+  const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w);
+  const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2);
+  const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w);  // batch * offset_group * 2 * kH * kW * out_h * out_w
+  const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w);  // batch * offset_group * kH * kW * out_h * out_w
+
+  std::vector X(x_size, 0.1f);
+  std::vector W(w_size, 0.1f);
+  std::vector offset(offset_size, 0.f);  // zero offset -> regular grid sampling
+  std::vector mask(mask_size, 1.f);
+  std::vector B = {0.5f, -0.5f};
+
+  // With offset=0, mask=1: deform_conv equals grouped conv. Per ONNX, group 0 -> output ch 0, group 1 -> ch 1.
+  // Uniform X=0.1, W=0.1, 2x2 kernel -> 0.08 + B per channel; Y[:,0,:,:]=0.58, Y[:,1,:,:]=-0.42.
+  const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w);
+  std::vector expected_Y(y_size);
+  for (int64_t b = 0; b < p.batch_sz; ++b) {
+    for (int64_t c = 0; c < p.n_out_channels; ++c) {
+      float val = (c % 2 == 0) ? 0.58f : -0.42f;
+      for (int64_t i = 0; i < out_h * out_w; ++i) {
+        expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val;
+      }
+    }
+  }
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f);  // (…, opset, rtol, atol)
+}
+
+// No mask (optional input omitted): single group, zero offset and zero bias, so every output equals the plain-conv value 0.08; checked against that constant.
+TEST(DeformConvTest, ForwardNoMask) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 2;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  const size_t x_size = 1 * 2 * 3 * 3;
+  const size_t w_size = 2 * 2 * 2 * 2;
+  const size_t offset_size = 1 * 2 * 2 * 2 * out_h * out_w;  // 8 offset channels (offset_group * 2 * kH * kW) per output position
+  const size_t y_size = 1 * 2 * out_h * out_w;
+
+  std::vector X(x_size, 0.1f);
+  std::vector W(w_size, 0.1f);
+  std::vector offset(offset_size, 0.f);
+  std::vector B(2, 0.f);
+  // No mask => mask=1. Zero offset => same as conv. Y = 4*2*0.1*0.1 = 0.08 per position.
+  std::vector expected_Y(y_size, 0.08f);
+
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", p.kernel_shape);
+  test.AddAttribute("strides", p.stride);
+  test.AddAttribute("pads", p.pad);
+  test.AddAttribute("dilations", p.dilation);
+  test.AddAttribute("group", p.n_weight_grps);
+  test.AddAttribute("offset_group", p.n_offset_grps);
+  const std::vector X_shape = {p.batch_sz, p.n_in_channels, p.in_h, p.in_w};
+  const std::vector W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2};
+  const std::vector offset_shape = {p.batch_sz, p.n_offset_grps * 2 * 2 * 2, out_h, out_w};
+  const std::vector Y_shape = {p.batch_sz, p.n_out_channels, out_h, out_w};
+
+  test.AddInput("X", X_shape, X);
+  test.AddInput("W", W_shape, W);
+  test.AddInput("offset", offset_shape, offset);
+  test.AddInput("B", {p.n_out_channels}, B);
+  test.AddOptionalInputEdge();  // no mask
+  test.AddOutput("Y", Y_shape, expected_Y, false, 1e-4f, 1e-4f);
+  std::unordered_set excluded = {kTensorrtExecutionProvider,  // NOTE(review): unlike DeformConvTestTraits::ExcludedProviders this omits kNvTensorRTRTXExecutionProvider - confirm intended
+                                 kOpenVINOExecutionProvider, kQnnExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded);
+}
+
+// Empty batch (N=0): allowed, same as Conv/ConvTranspose/Pool — output shape [0, oC, oH, oW].
+TEST(DeformConvTest, EmptyBatch) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 0;  // the zero-sized batch is the case under test
+  p.n_in_channels = 2;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;  // spatial output dims are still those of a 3x3 input with a 2x2 kernel
+  const int64_t out_w = 2;
+
+  std::vector X;  // no elements: leading dimension of X_shape is 0
+  std::vector W = std::vector(2 * 2 * 2 * 2, 0.1f);
+  std::vector offset;  // no elements: leading dimension of offset_shape is 0
+  std::vector B(2, 0.f);
+  std::vector expected_Y;  // no elements: leading dimension of Y_shape is 0
+
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", p.kernel_shape);
+  test.AddAttribute("strides", p.stride);
+  test.AddAttribute("pads", p.pad);
+  test.AddAttribute("dilations", p.dilation);
+  test.AddAttribute("group", p.n_weight_grps);
+  test.AddAttribute("offset_group", p.n_offset_grps);
+  const std::vector X_shape = {0, p.n_in_channels, p.in_h, p.in_w};
+  const std::vector W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2};
+  const std::vector offset_shape = {0, p.n_offset_grps * 2 * 2 * 2, out_h, out_w};
+  const std::vector Y_shape = {0, p.n_out_channels, out_h, out_w};
+
+  test.AddInput("X", X_shape, X);
+  test.AddInput("W", W_shape, W);
+  test.AddInput("offset", offset_shape, offset);
+  test.AddInput("B", {p.n_out_channels}, B);
+  test.AddOptionalInputEdge();  // mask omitted
+  test.AddOutput("Y", Y_shape, expected_Y);
+  std::unordered_set excluded = {kTensorrtExecutionProvider,  // NOTE(review): kNvTensorRTRTXExecutionProvider is not listed here, unlike DeformConvTestTraits - confirm intended
+                                 kOpenVINOExecutionProvider, kQnnExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded);
+}
+
+// Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes).
+TEST(DeformConvTest, WrongOffsetShape) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 2;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  std::vector X(1 * 2 * 3 * 3, 0.1f);
+  std::vector W(2 * 2 * 2 * 2, 0.1f);
+  std::vector wrong_offset(1 * 2 * out_h * out_w);  // wrong: only 2 channels instead of 8
+  std::vector B(2, 0.f);
+  std::vector expected_Y(1 * 2 * out_h * out_w, 0.f);  // placeholder values; the run is expected to fail
+
+  const std::vector offset_shape_wrong = {1, 2, out_h, out_w};
+  const std::vector Y_shape_wrong = {1, 2, out_h, out_w};
+
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", p.kernel_shape);
+  test.AddAttribute("strides", p.stride);
+  test.AddAttribute("pads", p.pad);
+  test.AddAttribute("dilations", p.dilation);
+  test.AddAttribute("group", p.n_weight_grps);
+  test.AddAttribute("offset_group", p.n_offset_grps);
+  test.AddInput("X", {1, 2, 3, 3}, X);
+  test.AddInput("W", {2, 2, 2, 2}, W);
+  test.AddInput("offset", offset_shape_wrong, wrong_offset);  // invalid channels
+  test.AddInput("B", {2}, B);
+  test.AddOptionalInputEdge();  // mask omitted
+  test.AddOutput("Y", Y_shape_wrong, expected_Y);
+  std::unordered_set excluded = {kTensorrtExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectFailure, "Offset channel count must be offset_group * 2 * kH * kW", excluded);  // the string is the error text the failure is expected to carry
+}
+
+// Wrong mask channel count -> expect failure.
+TEST(DeformConvTest, WrongMaskShape) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 2;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  std::vector X(1 * 2 * 3 * 3, 0.1f);
+  std::vector W(2 * 2 * 2 * 2, 0.1f);
+  const size_t offset_size = static_cast(
+      p.batch_sz * p.n_offset_grps * 2 * p.kernel_shape[0] * p.kernel_shape[1] * out_h * out_w);  // correctly sized offset: only the mask is invalid
+  std::vector offset(offset_size, 0.f);
+  std::vector B(2, 0.f);
+  std::vector wrong_mask(1 * 2 * out_h * out_w);  // wrong: 2 instead of 4
+  std::vector expected_Y(1 * 2 * out_h * out_w, 0.f);  // placeholder values; the run is expected to fail
+
+  const std::vector mask_shape_wrong = {1, 2, out_h, out_w};
+  const std::vector Y_shape_mask = {1, 2, out_h, out_w};
+
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", p.kernel_shape);
+  test.AddAttribute("strides", p.stride);
+  test.AddAttribute("pads", p.pad);
+  test.AddAttribute("dilations", p.dilation);
+  test.AddAttribute("group", p.n_weight_grps);
+  test.AddAttribute("offset_group", p.n_offset_grps);
+  test.AddInput("X", {1, 2, 3, 3}, X);
+  test.AddInput("W", {2, 2, 2, 2}, W);
+  test.AddInput("offset", {1, 8, out_h, out_w}, offset);
+  test.AddInput("B", {2}, B);
+  test.AddInput("mask", mask_shape_wrong, wrong_mask);
+  test.AddOutput("Y", Y_shape_mask, expected_Y);
+  std::unordered_set excluded = {kTensorrtExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectFailure, "Mask channel count", excluded);  // the string is the error text the failure is expected to carry
+}
+
+// Opset 22 (same behavior, different opset).
+TEST(DeformConvTest, Opset22) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {1, 1};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 2;
+  p.in_w = 2;
+
+  std::vector X = {1.f, 2.f, 3.f, 4.f};
+  std::vector W = {1.f};
+  std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, 0.f, 0.f, 0.f};  // only output (0,0) is shifted, by (0.5, 0.5)
+  std::vector B = {0.f};
+  std::vector mask = {1.f, 1.f, 1.f, 1.f};
+  std::vector expected_Y = {2.5f, 2.f, 3.f, 4.f};  // (0,0) samples the centre of X; the other positions return X unchanged
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22);  // last argument is the opset
+}
+
+// Non-square kernel (kH != kW): 2x3 kernel, zero offset -> same as standard conv.
+TEST(DeformConvTest, NonSquareKernel) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 3};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 4;
+  p.in_w = 5;
+  // ONNX output size: out_h = (4 - 1*(2-1) - 1)/1 + 1 = 3, out_w = (5 - 1*(3-1) - 1)/1 + 1 = 3
+  const int64_t out_h = 3;
+  const int64_t out_w = 3;
+
+  const size_t x_size = static_cast(1 * 1 * 4 * 5);
+  const size_t w_size = static_cast(1 * 1 * 2 * 3);
+  const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 3 * out_h * out_w);  // batch * n_offset_grps * 2 * kH * kW * out_h * out_w
+  const size_t mask_size = static_cast(1 * 1 * 2 * 3 * out_h * out_w);        // batch * n_offset_grps * kH * kW * out_h * out_w
+
+  std::vector X(x_size, 0.1f);
+  std::vector W(w_size, 0.1f);
+  std::vector offset(offset_size, 0.f);
+  std::vector mask(mask_size, 1.f);
+  std::vector B = {0.f};
+  // With offset=0, mask=1: each output = 6 * 0.1 * 0.1 = 0.06 (9 positions)
+  std::vector expected_Y(static_cast(out_h * out_w), 0.06f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Asymmetric stride (stride_h != stride_w): stride=(2,1), zero offset.
+TEST(DeformConvTest, AsymmetricStride) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {2, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 5;
+  p.in_w = 4;
+  // out_h = (5 - 1*(2-1) - 1) / 2 + 1 = 2, out_w = (4 - 1*(2-1) - 1) / 1 + 1 = 3
+  const int64_t out_h = 2;
+  const int64_t out_w = 3;
+
+  const size_t x_size = static_cast<size_t>(1 * 1 * 5 * 4);
+  const size_t w_size = static_cast<size_t>(1 * 1 * 2 * 2);
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 2 * 2 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 1 * 2 * 2 * out_h * out_w);
+
+  std::vector<float> X(x_size, 0.1f);
+  std::vector<float> W(w_size, 0.1f);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f};
+  std::vector<float> expected_Y(static_cast<size_t>(out_h * out_w), 0.04f);  // 4 taps * 0.1 * 0.1
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// group > 1 and non-zero offset; expected from deform_conv_expected_gen.py (seed=123).
+TEST(DeformConvTest, GroupsWithNonZeroOffset) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 4;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 2;
+  p.n_offset_grps = 2;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+
+  std::vector<float> X = {0.296112f, 0.516562f, 0.251671f, 0.688557f, 0.073972f, 0.866522f, 0.136580f, 0.102479f, 0.184056f, 0.726447f, 0.315254f, 0.687107f, 0.075635f, 0.196638f, 0.316412f, 0.401740f, 0.118568f, 0.827395f, 0.382084f, 0.660494f, 0.853572f, 0.593153f, 0.636725f, 0.982629f, 0.274495f, 0.658376f, 0.277542f, 0.857325f, 0.899328f, 0.039014f, 0.926823f, 0.738757f, 0.717884f, 0.705837f, 0.915650f, 0.433980f};
+  std::vector<float> W = {-1.182045f, -0.287745f, -0.604301f, 0.600237f, -1.420473f, -0.223828f, 0.430555f, -0.898857f, -0.017858f, 0.426403f, -0.765741f, -0.054514f, -0.732053f, 1.234742f, 1.186221f, -0.220099f};
+  std::vector<float> offset = {-0.388483f, -0.934346f, -0.499144f, -1.086653f, 0.962421f, 0.249208f, -0.484502f, -2.092915f, 0.098284f, -0.093507f, 0.266215f, -0.585035f, -0.343038f, -0.682148f, -0.988689f, -1.701830f, -1.220290f, 1.313853f, 1.053300f, 0.138805f, -0.204445f, -2.268529f, -0.913328f, -0.420363f, -0.659559f, -0.797928f, 0.183831f, 0.229347f, 0.617743f, -0.287578f, 0.821824f, 0.151178f, -0.044382f, 1.623557f, -2.322871f, 1.087831f, -0.063545f, -0.448641f, -1.278470f, -1.144004f, -0.152640f, 0.116741f, 0.440260f, -1.446546f, -0.558082f, -0.051696f, -0.908273f, 0.350683f, -0.394809f, 0.489227f, -0.216815f, -1.747165f, 1.722842f, 0.773806f, 0.404630f, -1.646126f, -0.595084f, -0.711218f, 0.622965f, -1.372881f, -0.128065f, -1.283835f, -0.290120f, 1.276741f};
+  std::vector<float> B = {0.983955f, 0.204512f};
+  std::vector<float> mask = {-0.031861f, -0.478956f, 0.766809f, 0.027468f, 0.047470f, -0.923866f, -1.060737f, -2.324446f, -2.062818f, 0.006375f, -0.989555f, 0.701609f, -0.982238f, 0.277031f, 0.645495f, -0.895681f, 0.492753f, -0.014078f, -0.274663f, -0.764091f, -0.587157f, 1.195165f, -1.209575f, -0.556008f, -0.077105f, 1.277377f, -1.459629f, -2.159528f, -0.706709f, -0.922245f, 3.895372f, -0.602697f};
+  std::vector<float> expected_Y = {0.971546f, 1.139858f, 0.452817f, 1.863882f, -0.565266f, 1.423187f, -2.462833f, -0.104923f};
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f);
+}
+
+// Sampling out of bounds: offset pushes sampling to (-5,-5), BilinearInterpolate returns 0.
+TEST(DeformConvTest, OutOfBoundsSampling) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {1, 1};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 2;
+  p.in_w = 2;
+
+  std::vector<float> X = {1.f, 2.f, 3.f, 4.f};
+  std::vector<float> W = {1.f};
+  // out_h=out_w=2 (2x2 output), offset shape [1, 2, 2, 2] = 8 values. All (-5,-5) -> OOB -> 0
+  std::vector<float> offset = {-5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f};
+  std::vector<float> B = {0.f};
+  std::vector<float> mask = {1.f, 1.f, 1.f, 1.f};
+  std::vector<float> expected_Y = {0.f, 0.f, 0.f, 0.f};
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Dilation > 1: 2x2 kernel with dilation (2,2), zero offset -> 4 sample points with stride 2.
+TEST(DeformConvTest, DilationGt1) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {2, 2};
+  p.in_h = 5;
+  p.in_w = 5;
+  // out_h = (5 - 2*(2-1) - 1)/1 + 1 = 3, out_w = 3
+  const int64_t out_h = 3;
+  const int64_t out_w = 3;
+
+  const size_t x_size = 25;
+  const size_t w_size = 4;
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 2 * 2 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 1 * 2 * 2 * out_h * out_w);
+
+  std::vector<float> X(x_size, 0.1f);
+  std::vector<float> W(w_size, 0.1f);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f};
+  // Each output: 4 samples at (0,0),(0,2),(2,0),(2,2) -> 4 * 0.1 * 0.1 = 0.04
+  std::vector<float> expected_Y(static_cast<size_t>(out_h * out_w), 0.04f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Decoupled groups: group=2, offset_group=1 (one offset map shared by all input channels).
+TEST(DeformConvTest, DecoupledGroups) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 4;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 2;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  const size_t x_size = static_cast<size_t>(1 * 4 * 3 * 3);
+  const size_t w_size = static_cast<size_t>(2 * 2 * 2 * 2);
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 2 * 2 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 1 * 2 * 2 * out_h * out_w);
+
+  std::vector<float> X(x_size, 0.1f);
+  std::vector<float> W(w_size, 0.1f);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f, 0.f};
+  // Zero offset -> grouped conv. Per output ch: 2 in_ch * 4 kernel * 0.01 = 0.08
+  std::vector<float> expected_Y(static_cast<size_t>(1 * 2 * out_h * out_w), 0.08f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Asymmetric padding: pads [top=1, left=0, bottom=0, right=1]; output 3x3, some positions have OOB samples.
+TEST(DeformConvTest, AsymmetricPadding) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {1, 0, 0, 1};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  // out_h = (3+1+0-1*(2-1)-1)/1+1 = 3, out_w = (3+0+1-1-1)/1+1 = 3
+  const int64_t out_h = 3;
+  const int64_t out_w = 3;
+
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 2 * 2 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 1 * 2 * 2 * out_h * out_w);
+
+  std::vector<float> X(1 * 1 * 3 * 3, 0.1f);
+  std::vector<float> W(1 * 1 * 2 * 2, 0.1f);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f};
+  // Row 0 (top taps in padding): 2 valid -> 0.02; last col also loses the right tap -> 1 valid -> 0.01. Rows 1/2: 4 valid -> 0.04; last col 2 valid -> 0.02.
+  std::vector<float> expected_Y = {0.02f, 0.02f, 0.01f, 0.04f, 0.04f, 0.02f, 0.04f, 0.04f, 0.02f};
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Tiny offset (near zero): offset (1e-6, 1e-6), sample ~(0,0) -> bilinear ≈ X[0,0]. Use 1x1 input for 1 output.
+TEST(DeformConvTest, TinyOffset) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {1, 1};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 1;
+  p.in_w = 1;
+
+  std::vector<float> X = {1.f};
+  std::vector<float> W = {1.f};
+  std::vector<float> offset = {1e-6f, 1e-6f};
+  std::vector<float> B = {0.f};
+  std::vector<float> mask = {1.f};
+  std::vector<float> expected_Y = {1.f};  // bilinear at (1e-6, 1e-6) ≈ 1
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f);
+}
+
+// Offset (0.5, 0.5) at each kernel point: sampling at (i+0.5, j+0.5) -> (0.5,0.5),(0.5,1.5),(1.5,0.5),(1.5,1.5).
+// Only (0.5,0.5) is fully in-bounds for 2x2 input; others hit boundary (OOB gives 0). Result = 1.6875.
+TEST(DeformConvTest, OffsetAtPixelCenters) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 2;
+  p.in_w = 2;
+
+  std::vector<float> X = {1.f, 2.f, 3.f, 4.f};
+  std::vector<float> W = {0.25f, 0.25f, 0.25f, 0.25f};
+  std::vector<float> offset = {
+      0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
+  std::vector<float> B = {0.f};
+  std::vector<float> mask = {1.f, 1.f, 1.f, 1.f};
+  std::vector<float> expected_Y = {1.6875f};  // 0.25 * (2.5 + 1.5 + 1.75 + 1.0): center sample + 3 partially-OOB samples
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Large batch (N=64) to trigger CUDA ComputeInternal chunking loop (b += n_parallel_imgs).
+TEST(DeformConvTest, LargeBatchSize) {
+  const int64_t N = 64;
+  DeformConvTestParams p = {};
+  p.batch_sz = N;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  const size_t x_size = static_cast<size_t>(N * 1 * 3 * 3);
+  const size_t offset_size = static_cast<size_t>(N * 1 * 2 * 2 * 2 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(N * 1 * 2 * 2 * out_h * out_w);
+  const size_t y_size = static_cast<size_t>(N * 1 * out_h * out_w);
+
+  std::vector<float> X(x_size, 0.1f);
+  std::vector<float> W(1 * 1 * 2 * 2, 0.1f);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f};
+  std::vector<float> expected_Y(y_size, 0.04f);  // 4 * 0.1 * 0.1 per position
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// group=1, offset_group=2: weights not grouped, offset/mask grouped.
+TEST(DeformConvTest, Group1OffsetGroup2) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 4;  // C must be divisible by offset_group
+  p.n_out_channels = 2;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 2;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  const size_t x_size = static_cast<size_t>(1 * 4 * 3 * 3);
+  const size_t w_size = static_cast<size_t>(2 * 4 * 2 * 2);
+  const size_t offset_size = static_cast<size_t>(1 * 2 * 2 * 2 * 2 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 2 * 2 * 2 * out_h * out_w);
+
+  std::vector<float> X(x_size, 0.1f);
+  std::vector<float> W(w_size, 0.1f);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f, 0.f};
+  // group=1: full conv. Each output: 4 in_ch * 4 kernel = 16 * 0.01 = 0.16 per channel, 2 out ch -> 0.16 each
+  std::vector<float> expected_Y(static_cast<size_t>(1 * 2 * out_h * out_w), 0.16f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Mask with zeros: exercises CUDA early-exit when mask_val == 0.
+TEST(DeformConvTest, MaskWithZeros) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  std::vector<float> X(1 * 1 * 3 * 3, 0.1f);
+  std::vector<float> W(1 * 1 * 2 * 2, 0.1f);
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 2 * 2 * out_h * out_w);
+  std::vector<float> offset(offset_size, 0.f);
+  // mask: (1, 4, 2, 2). Set all to 0 -> output should be 0.
+  std::vector<float> mask(static_cast<size_t>(1 * 1 * 2 * 2 * out_h * out_w), 0.f);
+  std::vector<float> B = {0.f};
+  std::vector<float> expected_Y(static_cast<size_t>(out_h * out_w), 0.f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// Extreme aspect ratio (1x100): thin horizontal strip to verify coordinate indexing.
+TEST(DeformConvTest, ExtremeAspectRatio) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {1, 3};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 1;
+  p.in_w = 100;
+  // out_h = 1, out_w = (100 - 1*(3-1) - 1)/1 + 1 = 98
+  const int64_t out_h = 1;
+  const int64_t out_w = 98;
+
+  std::vector<float> X(100, 0.1f);
+  std::vector<float> W(1 * 1 * 1 * 3, 0.1f);
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 1 * 3 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 1 * 1 * 3 * out_h * out_w);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f};
+  // Each output: 3 * 0.1 * 0.1 = 0.03
+  std::vector<float> expected_Y(static_cast<size_t>(out_h * out_w), 0.03f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// ONNX model data test: deform_conv_test_gen.py builds the ONNX model (via onnx.helper)
+// and generates fixed inputs from torchvision (seed=123). This test is a model-loading/
+// integration smoke test that uses ORT-generated outputs from deform_conv_test.onnx as the reference.
+TEST(DeformConvTest, OnnxModelTest) {
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", std::vector<int64_t>{2, 2});
+  test.AddAttribute("strides", std::vector<int64_t>{1, 1});
+  test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
+  test.AddAttribute("dilations", std::vector<int64_t>{1, 1});
+  test.AddAttribute("group", static_cast<int64_t>(2));
+  test.AddAttribute("offset_group", static_cast<int64_t>(2));
+
+  test.AddInput<float>("X", {1, 4, 3, 3}, kDeformConvOnnxTest_X);
+  test.AddInput<float>("W", {2, 2, 2, 2}, kDeformConvOnnxTest_W);
+  test.AddInput<float>("offset", {1, 16, 2, 2}, kDeformConvOnnxTest_offset);
+  test.AddInput<float>("B", {2}, kDeformConvOnnxTest_B);
+  test.AddInput<float>("mask", {1, 8, 2, 2}, kDeformConvOnnxTest_mask);
+  test.AddReferenceOutputs("testdata/deform_conv_test.onnx", 1e-4f);
+
+  std::unordered_set<std::string> excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider,
+                                              kQnnExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
index 1eeb3683bc9aa..b2abe353693a2 100644
--- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
+++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "test/common/trt_op_test_utils.h" @@ -906,5 +907,138 @@ TEST(RoiAlignTest, BatchIndicesNegative_CUDA) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); #endif } + +TEST(RoiAlignTest, Float16_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, Float16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + 
test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, BFloat16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToBFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToBFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToBFloat16({1.25f})); + + test.SetOutputTolerance(0.05f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test half_pixel mode (default for Opset 16+) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, sampling_ratio=2. 
+// Expected values from ONNX reference implementation: {11.25, 14.75, 39.25, 42.75}
+TEST(RoiAlignTest, Float16_HalfPixel_Opset16) {
+  auto cuda_ep = DefaultCudaExecutionProvider();
+  if (cuda_ep.get() == nullptr) {
+    GTEST_SKIP() << "Skipping because there is no CUDA execution provider available.";
+  }
+
+  OpTester test("RoiAlign", 16);
+  test.AddAttribute<int64_t>("output_height", 2);
+  test.AddAttribute<int64_t>("output_width", 2);
+  test.AddAttribute<int64_t>("sampling_ratio", 2);
+  test.AddAttribute<float>("spatial_scale", 1.0f);
+  test.AddAttribute<std::string>("coordinate_transformation_mode", "half_pixel");
+
+  std::vector<float> X_val(64);
+  for (int i = 0; i < 64; ++i) X_val[i] = static_cast<float>(i);
+  test.AddInput<MLFloat16>("X", {1, 1, 8, 8}, ToFloat16(X_val));
+  test.AddInput<MLFloat16>("rois", {1, 4}, ToFloat16({0., 0., 7., 7.}));
+  test.AddInput<int64_t>("batch_indices", {1}, {0});
+  test.AddOutput<MLFloat16>("Y", {1, 1, 2, 2}, ToFloat16({11.25f, 14.75f, 39.25f, 42.75f}));
+
+  test.SetOutputTolerance(0.01f);
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(std::move(cuda_ep));
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+// Test adaptive sampling (sampling_ratio=0) with Float16 on larger spatial dimensions.
+// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, half_pixel mode.
+// Adaptive: ceil(3.0/2)=2 samples per dim.
+// Expected values from ONNX reference implementation: {11.39062, 14.875, 39.26562, 42.75}
+TEST(RoiAlignTest, Float16_AdaptiveSampling_Opset16) {
+  auto cuda_ep = DefaultCudaExecutionProvider();
+  if (cuda_ep.get() == nullptr) {
+    GTEST_SKIP() << "Skipping because there is no CUDA execution provider available.";
+  }
+
+  OpTester test("RoiAlign", 16);
+  test.AddAttribute<int64_t>("output_height", 2);
+  test.AddAttribute<int64_t>("output_width", 2);
+  test.AddAttribute<int64_t>("sampling_ratio", 0);  // adaptive
+  test.AddAttribute<float>("spatial_scale", 1.0f);
+  test.AddAttribute<std::string>("coordinate_transformation_mode", "half_pixel");
+
+  std::vector<float> X_val(64);
+  for (int i = 0; i < 64; ++i) X_val[i] = static_cast<float>(i);
+  test.AddInput<MLFloat16>("X", {1, 1, 8, 8}, ToFloat16(X_val));
+  test.AddInput<MLFloat16>("rois", {1, 4}, ToFloat16({0., 0., 7., 7.}));
+  test.AddInput<int64_t>("batch_indices", {1}, {0});
+  test.AddOutput<MLFloat16>("Y", {1, 1, 2, 2},
+                            ToFloat16({11.39062f, 14.875f, 39.26562f, 42.75f}));
+
+  test.SetOutputTolerance(0.01f);
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(std::move(cuda_ep));
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
index 9169f2e6b5ca9..990e4354c3626 100644
--- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
@@ -124,6 +124,37 @@ static void RunAllOpsetAllDomainPadTests(
   }
 }
 
+#ifdef USE_CUDA
+template <typename T>
+static void RunCudaOnlyOnnxOpsetPadTest(
+    int opset,
+    const std::vector<int64_t>& input_dims,
+    const std::vector<T>& input,
+    const std::vector<int64_t>& pads,
+    T value,
+    const std::vector<int64_t>& output_dims,
+    const std::vector<T>& output,
+    const std::string& mode = "constant") {
+  auto cuda_execution_provider = DefaultCudaExecutionProvider();
+  if (cuda_execution_provider == nullptr) {
+    GTEST_SKIP() << "CUDA execution provider is not available";
+  }
+
+  OpTester test("Pad", opset);
+  if (mode != "constant") {
+    test.AddAttribute("mode", mode);
+  }
+  test.AddInput<T>("data", input_dims, input);
+  test.AddInput<int64_t>("pads", {static_cast<int64_t>(pads.size())}, pads, true);
+  test.AddInput<T>("value", {}, {value}, true);
+  test.AddOutput<T>("output", output_dims, output);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.emplace_back(std::move(cuda_execution_provider));
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+#endif
+
 // Some of the tests can't run on TensorrtExecutionProvider because only constant mode and value 0 of "Pad" node is supported.
 // Those tests will fallback to other EP.
 
@@ -199,6 +230,48 @@ TYPED_TEST(PadOpTest, Pad_Edge_1D) {
                                   "edge");
 }
 
+#ifdef USE_CUDA
+TEST(PadOpTest, Pad_Edge_CudaOnly_MLFloat16_SupportedOpsets) {
+  const std::vector<int> supported_opsets{18, 19, 20, 21, 22, 23, 24, 25};
+  for (int opset : supported_opsets) {
+    SCOPED_TRACE(MakeString("opset: ", opset));
+    RunCudaOnlyOnnxOpsetPadTest<MLFloat16>(
+        opset,
+        {3, 2},
+        {MLFloat16(1.0f), MLFloat16(2.0f),
+         MLFloat16(3.0f), MLFloat16(4.0f),
+         MLFloat16(5.0f), MLFloat16(6.0f)},
+        {0, 2, 0, 1},
+        MLFloat16(0.0f),
+        {3, 5},
+        {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f),
+         MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(4.0f),
+         MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(6.0f)},
+        "edge");
+  }
+}
+
+TEST(PadOpTest, Pad_Wrap_CudaOnly_Float_SupportedOpsets) {
+  const std::vector<int> supported_opsets{19, 20, 21, 22, 23, 24, 25};
+  for (int opset : supported_opsets) {
+    SCOPED_TRACE(MakeString("opset: ", opset));
+    RunCudaOnlyOnnxOpsetPadTest<float>(
+        opset,
+        {3, 2},
+        {1.0f, 2.0f,
+         3.0f, 4.0f,
+         5.0f, 6.0f},
+        {0, 1, 0, 1},
+        0.0f,
+        {3, 4},
+        {2.0f, 1.0f, 2.0f, 1.0f,
+         4.0f, 3.0f, 4.0f, 3.0f,
+         6.0f, 5.0f, 6.0f, 5.0f},
+        "wrap");
+  }
+}
+#endif
+
 TYPED_TEST(PadOpTest, Pad_Constant_2D) {
   using T = TypeParam;
RunAllOpsetAllDomainPadTests({2, 2}, @@ -1391,9 +1464,7 @@ TEST(PadOpTest, Pad_Wrap_NegativeFront_PositiveBack) { // Post-slice core: [4]; wrap 3 -> [4, 4, 4, 4] const std::vector expected_data = {4, 4, 4, 4}; - // CUDA registers only up to 18 and does not impl wrap mode - // so we force version to 19 to automatically exclude EPs that do not - // implement wrap mode similar to the above tests. + // Use opset 19 to exercise wrap mode, which is supported from Pad-19 onward. OpTester test("Pad", 19); test.AddInput("data", input_shape, input_data); test.AddInput("pads", {static_cast(pads.size())}, pads, true); diff --git a/onnxruntime/test/testdata/deform_conv_test.onnx b/onnxruntime/test/testdata/deform_conv_test.onnx new file mode 100644 index 0000000000000..b643014e44acb Binary files /dev/null and b/onnxruntime/test/testdata/deform_conv_test.onnx differ diff --git a/onnxruntime/test/testdata/deform_conv_test_data.inc b/onnxruntime/test/testdata/deform_conv_test_data.inc new file mode 100644 index 0000000000000..206d8517dd3e3 --- /dev/null +++ b/onnxruntime/test/testdata/deform_conv_test_data.inc @@ -0,0 +1,10 @@ +// Auto-generated by deform_conv_test_gen.py - do not edit + +#include + +static const std::vector kDeformConvOnnxTest_X = {0.296111941f, 0.516562283f, 0.251670718f, 0.68855679f, 0.0739724636f, 0.866521955f, 0.136579871f, 0.102479041f, 0.184056461f, 0.726446748f, 0.315253913f, 0.687106669f, 0.075635314f, 0.196638167f, 0.316411972f, 0.401740134f, 0.118568301f, 0.82739538f, 0.382084429f, 0.660493851f, 0.853571773f, 0.593153f, 0.636725366f, 0.982629359f, 0.274495304f, 0.658375621f, 0.277541935f, 0.857324839f, 0.899328232f, 0.0390138626f, 0.926822901f, 0.738757193f, 0.717883527f, 0.705837429f, 0.915649533f, 0.433980227f}; +static const std::vector kDeformConvOnnxTest_W = {-1.18204546f, -0.287744999f, -0.604300678f, 0.600236714f, -1.42047262f, -0.223827749f, 0.430554837f, -0.89885664f, -0.0178579595f, 0.426403075f, -0.765740693f, -0.0545141846f, -0.732052684f, 
1.23474216f, 1.18622088f, -0.220098898f}; +static const std::vector kDeformConvOnnxTest_offset = {-0.388483077f, -0.934345901f, -0.499144107f, -1.08665264f, 0.962421f, 0.249208495f, -0.484502077f, -2.09291434f, 0.0982837752f, -0.0935074314f, 0.266214728f, -0.585035503f, -0.343037993f, -0.682147384f, -0.988689423f, -1.70183039f, -1.2202903f, 1.31385386f, 1.05329967f, 0.138805181f, -0.204444751f, -2.26852894f, -0.913327932f, -0.420362711f, -0.659559608f, -0.797927678f, 0.18383126f, 0.229347408f, 0.617742658f, -0.287577927f, 0.821824312f, 0.151177585f, -0.0443819836f, 1.62355745f, -2.32287097f, 1.08783054f, -0.0635453761f, -0.448640704f, -1.27846932f, -1.14400387f, -0.152640373f, 0.116741188f, 0.44026047f, -1.44654655f, -0.558081627f, -0.0516963229f, -0.90827328f, 0.350683212f, -0.394808769f, 0.489227712f, -0.216814891f, -1.74716449f, 1.72284174f, 0.773806036f, 0.404629797f, -1.64612663f, -0.59508425f, -0.711217523f, 0.622964859f, -1.37288189f, -0.128064156f, -1.28383458f, -0.290120065f, 1.27674019f}; +static const std::vector kDeformConvOnnxTest_B = {0.983955026f, 0.204511523f}; +static const std::vector kDeformConvOnnxTest_mask = {-0.0318612382f, -0.478955716f, 0.766808629f, 0.0274681915f, 0.0474699028f, -0.92386651f, -1.06073678f, -2.32444572f, -2.06281757f, 0.00637452863f, -0.989554703f, 0.701609194f, -0.982237995f, 0.277030349f, 0.645495057f, -0.895680785f, 0.492752999f, -0.0140781598f, -0.274662733f, -0.764091492f, -0.58715719f, 1.1951654f, -1.20957518f, -0.556007624f, -0.0771045536f, 1.27737665f, -1.45962942f, -2.15952778f, -0.70670861f, -0.92224431f, 3.89537215f, -0.602696717f}; +static const std::vector kDeformConvOnnxTest_expected_Y = {0.971546292f, 1.1398586f, 0.452816963f, 1.86388242f, -0.565265715f, 1.42318761f, -2.46283293f, -0.104923099f}; diff --git a/onnxruntime/test/testdata/deform_conv_test_data.npz b/onnxruntime/test/testdata/deform_conv_test_data.npz new file mode 100644 index 0000000000000..68639f753501d Binary files /dev/null and 
b/onnxruntime/test/testdata/deform_conv_test_data.npz differ
diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py
new file mode 100644
index 0000000000000..120fb1ed4c211
--- /dev/null
+++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py
@@ -0,0 +1,197 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""
+Generate DeformConv ONNX model and test data for cross-platform validation.
+
+Based on ONNX DeformConv spec (opset 19+): https://onnx.ai/onnx/operators/onnx__DeformConv.html
+Uses a moderately complex config: groups=2, offset_group=2, 2x2 kernel, non-zero offsets.
+Reference output from torchvision.ops.deform_conv2d.
+
+Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads
+[pad_top, pad_left, pad_bottom, pad_right] = [pad_h, pad_w, pad_h, pad_w]. Asymmetric
+pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated.
+
+Run from repo root:
+    python onnxruntime/test/testdata/nn/deform_conv_test_gen.py
+
+Outputs:
+  - deform_conv_test.onnx
+  - deform_conv_test_data.npz (X, W, offset, B, mask, expected_Y)
+  - deform_conv_test_data.inc (C++ arrays for op test)
+"""
+
+from pathlib import Path
+
+import numpy as np
+from onnx import TensorProto, checker, helper, save
+
+try:
+    import onnxruntime as ort
+except ImportError:
+    ort = None
+import torch
+import torchvision.ops
+
+# Config: groups=2, offset_group=2, 2x2 kernel (from deform_conv_expected_gen Case 3)
+BATCH = 1
+N_IN = 4
+N_OUT = 2
+N_WEIGHT_GRPS = 2
+N_OFFSET_GRPS = 2
+KH, KW = 2, 2
+STRIDE_H, STRIDE_W = 1, 1
+PAD_H, PAD_W = 0, 0
+DIL_H, DIL_W = 1, 1
+IN_H, IN_W = 3, 3
+SEED = 123
+
+OUT_H = (IN_H + 2 * PAD_H - (DIL_H * (KH - 1) + 1)) // STRIDE_H + 1
+OUT_W = (IN_W + 2 * PAD_W - (DIL_W * (KW - 1) + 1)) // STRIDE_W + 1
+
+
+def _generate_reference():
+    """Generate inputs and expected output via torchvision.ops.deform_conv2d."""
+    torch.manual_seed(SEED)
+    x = torch.rand(BATCH, N_IN, IN_H, IN_W, dtype=torch.float32)
+    offset = torch.randn(BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W, dtype=torch.float32)
+    mask = torch.randn(BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W, dtype=torch.float32)
+    weight = torch.randn(N_OUT, N_IN // N_WEIGHT_GRPS, KH, KW, dtype=torch.float32)
+    bias = torch.randn(N_OUT, dtype=torch.float32)
+
+    out = torchvision.ops.deform_conv2d(
+        x,
+        offset,
+        weight,
+        bias=bias,
+        stride=(STRIDE_H, STRIDE_W),
+        padding=(PAD_H, PAD_W),
+        dilation=(DIL_H, DIL_W),
+        mask=mask,
+    )
+
+    return {
+        "X": x.numpy(),
+        "W": weight.numpy(),
+        "offset": offset.numpy(),
+        "B": bias.numpy(),
+        "mask": mask.numpy(),
+        "expected_Y": out.numpy(),
+    }
+
+
+def _build_onnx_model():
+    """Build DeformConv ONNX model. ONNX pads = [pad_top, pad_left, pad_bottom, pad_right]."""
+    # Symmetric padding only: (pad_h, pad_w) -> [pad_h, pad_w, pad_h, pad_w]
+    pads = [PAD_H, PAD_W, PAD_H, PAD_W]
+
+    node = helper.make_node(
+        "DeformConv",
+        inputs=["X", "W", "offset", "B", "mask"],
+        outputs=["Y"],
+        kernel_shape=[KH, KW],
+        strides=[STRIDE_H, STRIDE_W],
+        pads=pads,
+        dilations=[DIL_H, DIL_W],
+        group=N_WEIGHT_GRPS,
+        offset_group=N_OFFSET_GRPS,
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "DeformConvTest",
+        [
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, [BATCH, N_IN, IN_H, IN_W]),
+            helper.make_tensor_value_info("W", TensorProto.FLOAT, [N_OUT, N_IN // N_WEIGHT_GRPS, KH, KW]),
+            helper.make_tensor_value_info(
+                "offset", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W]
+            ),
+            helper.make_tensor_value_info("B", TensorProto.FLOAT, [N_OUT]),
+            helper.make_tensor_value_info("mask", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W]),
+        ],
+        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [BATCH, N_OUT, OUT_H, OUT_W])],
+    )
+
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])
+    checker.check_model(model)
+    return model
+
+
+def _to_cpp_array(name: str, arr: np.ndarray) -> str:
+    """Format numpy array as a C++ `std::vector<float>` definition with a flat initializer list."""
+    flat = arr.flatten().tolist()
+    vals = ", ".join(f"{x:.9g}f" for x in flat)  # 9 significant digits round-trip any float32
+    return f"static const std::vector<float> {name} = {{{vals}}};"
+
+
+def _write_cpp_inc(data: dict, inc_path: Path) -> None:
+    """Write C++ include file with test data."""
+    lines = [
+        "// Auto-generated by deform_conv_test_gen.py - do not edit",
+        "",
+        "#include <vector>",
+        "",
+        _to_cpp_array("kDeformConvOnnxTest_X", data["X"]),
+        _to_cpp_array("kDeformConvOnnxTest_W", data["W"]),
+        _to_cpp_array("kDeformConvOnnxTest_offset", data["offset"]),
+        _to_cpp_array("kDeformConvOnnxTest_B", data["B"]),
+        _to_cpp_array("kDeformConvOnnxTest_mask", data["mask"]),
+        _to_cpp_array("kDeformConvOnnxTest_expected_Y", data["expected_Y"]),
+        "",
+    ]
+    inc_path.write_text("\n".join(lines), encoding="utf-8")
+
+
+def main():
+    """Write the ONNX model, .npz and .inc test data; cross-check against ORT when installed."""
+    # Output to testdata/ root (same as layernorm.onnx, attention_past_state.onnx, etc.)
+    script_dir = Path(__file__).resolve().parent
+    if script_dir.name != "nn":  # explicit check: `assert` is stripped under `python -O`
+        raise RuntimeError("Script must live in testdata/nn/")
+    testdata_root = script_dir.parent
+    model_path = testdata_root / "deform_conv_test.onnx"
+    data_path = testdata_root / "deform_conv_test_data.npz"
+    inc_path = testdata_root / "deform_conv_test_data.inc"
+
+    print("Generating reference via torchvision.ops.deform_conv2d...")
+    data = _generate_reference()
+
+    print("Building ONNX model...")
+    model = _build_onnx_model()
+    save(model, str(model_path))
+    print(f" Saved {model_path}")
+
+    np.savez(str(data_path), **data)
+    print(f" Saved {data_path}")
+
+    _write_cpp_inc(data, inc_path)
+    print(f" Saved {inc_path}")
+
+    # Validate with onnxruntime if available
+    if ort is not None:
+        print("Validating with ONNX Runtime...")
+        sess = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
+        ort_out = sess.run(
+            ["Y"],
+            {
+                "X": data["X"],
+                "W": data["W"],
+                "offset": data["offset"],
+                "B": data["B"],
+                "mask": data["mask"],
+            },
+        )[0]
+
+        rtol, atol = 1e-4, 1e-4
+        if np.allclose(ort_out, data["expected_Y"], rtol=rtol, atol=atol):
+            print(" PASS: ORT output matches reference.")
+        else:
+            diff = np.abs(ort_out.astype(np.float64) - data["expected_Y"].astype(np.float64))
+            print(f" FAIL: max |diff|={diff.max()}, mean={diff.mean()}")
+            raise SystemExit(1)  # a failed cross-check must not look like a successful run
+    else:
+        print(" (onnxruntime not installed; skip validation)")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml
index f767ef110561a..88d2981e2ccaa 100644
--- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml
@@ -194,6
+194,7 @@ stages: - template: ../templates/py-linux-qnn.yml parameters: machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' + QnnSdk: ${{ parameters.qnn_sdk_version }} extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true