diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 872d3b182c310..e67eda41d2e0e 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -155,6 +155,86 @@ jobs: working-directory: ${{ github.workspace }}\csharp continue-on-error: true + webgpu_plugin_build_x64_RelWithDebInfo: + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=webgpu_plugin_build_x64_RelWithDebInfo-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] + timeout-minutes: 300 + env: + OnnxRuntimeBuildDirectory: ${{ github.workspace }} + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: "0" + DocUpdateNeeded: false + NVIDIA_TF32_OVERRIDE: "0" + ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: none + + - name: Setup Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: "3.12" + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + run: python -m pip install -r tools\ci_build\github\windows\python\requirements.txt + shell: cmd + working-directory: ${{ github.workspace }} + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: "20.x" + + - uses: actions/cache@v5 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: Build and Test + shell: pwsh + run: | + python.exe ${{ github.workspace }}\tools\ci_build\build.py ` + --config RelWithDebInfo ` + --build_dir ${{ github.workspace }} ` + --skip_submodule_sync ` + --parallel ` + --use_binskim_compliant_compile_flags ` + --cmake_generator "Visual Studio 17 2022" ` + --enable_onnx_tests ` + --use_webgpu shared_lib ` + --wgsl_template static ` + --use_vcpkg 
--use_vcpkg_ms_internal_asset_cache ` + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_DAWN_BACKEND_D3D12=1 onnxruntime_ENABLE_DAWN_BACKEND_VULKAN=1 ` + --disable_rtti ` + --enable_lto + + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + - name: Publish artifacts + uses: actions/upload-artifact@v4 + with: + name: webgpu-plugin-binaries + path: | + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_webgpu.dll + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_webgpu.pdb + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/dxcompiler.dll + ${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/dxil.dll + webgpu_external_dawn_build_x64_RelWithDebInfo: runs-on: [ "self-hosted", diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 0156e46b86bc4..4f75a8b105ec2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -34,6 +34,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.h ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/silu.cpp + ${MLAS_SRC_DIR}/gelu.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp @@ -201,6 +203,14 @@ function(setup_mlas_source_for_windows) ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set(mlas_platform_srcs_avx512 + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + + set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512") + target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/dgemm.cpp ${mlas_platform_srcs_avx} @@ -212,7 +222,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - 
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${mlas_platform_srcs_avx512} ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp @@ -764,6 +774,8 @@ endif() ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index be7f2613a6272..cd29e4dad0a17 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -56,7 +56,7 @@ endif() source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs}) - onnxruntime_add_shared_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu ${REPO_ROOT}/include/onnxruntime/core/session onnxruntime_common @@ -119,6 +119,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") message(FATAL_ERROR "WebGPU EP shared library build is not supported on Emscripten. 
Please use static library build.") endif() + + # Configure precompiled headers for shared library build + # PCH ensures ep/adapters.h is included first and improves compilation speed + target_precompile_headers(onnxruntime_providers_webgpu PRIVATE + "${REPO_ROOT}/include/onnxruntime/ep/adapters.h" + ) endif() set_target_properties(onnxruntime_providers_webgpu PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9ae3e79d86443..8137f8b3a2529 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1042,6 +1042,18 @@ function(onnxruntime_apply_test_target_workarounds target) endif() endfunction() +# Set environment variables for plugin EP tests when run via CTest. +function(onnxruntime_set_plugin_ep_test_environment target) + if(onnxruntime_USE_WEBGPU AND onnxruntime_USE_EP_API_ADAPTERS) + set(ORT_PLUGIN_EP_JSON_CONFIG "{\"ep_library_registration_name\": \"WebGPU_PluginEP\", \"ep_library_path\": \"$\", \"selected_ep_name\": \"WebGpuExecutionProvider\"}") + set_tests_properties(${target} PROPERTIES + ENVIRONMENT "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON=${ORT_PLUGIN_EP_JSON_CONFIG}" + ) + # TODO: add for other plugin EPs if needed + # elseif() + endif() +endfunction() + function(onnxruntime_apply_emscripten_test_link_settings target) if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(${target} PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) @@ -1250,6 +1262,7 @@ block() ) onnxruntime_apply_test_target_workarounds(onnxruntime_provider_test) + onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test) # Expose QNN SDK headers to unit tests via an interface target if(onnxruntime_USE_QNN) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 39c9145a40912..625cc4e09ca13 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -103,6 +103,8 @@ Do not modify directly.* |||[11, 
13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DeformConv|*in* X:**T**
*in* W:**T**
*in* offset:**T**
*in* B:**T**
*in* mask:**T**
*out* Y:**T**|22+|**T** = tensor(double), tensor(float)| +|||[19, 21]|**T** = tensor(double), tensor(float)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -697,6 +699,8 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| +|DeformConv|*in* X:**T**
*in* W:**T**
*in* offset:**T**
*in* B:**T**
*in* mask:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -843,7 +847,12 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|25+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||24|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||23|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -902,7 +911,9 @@ Do not modify directly.* |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|22+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index f0a99bc11c8b3..a9d9ac8323b16 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -391,6 +391,13 @@ static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_k // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; +// Block size used when converting per-tensor or per-axis DQ + MatMul to MatMulNBits. +// Only applies to DQ nodes without an existing block_size attribute (i.e., per-tensor or per-axis quantization). +// Positive value: explicit block_size (must be power-of-2 and >= 16, e.g., 16, 32, 64, 128). +// "0" or not provided: use default block_size of 32. +// "-1": heuristic - largest power-of-2 <= min(K, 256) that minimizes padding. +static const char* const kOrtSessionOptionsQDQMatMulNBitsBlockSize = "session.qdq_matmulnbits_block_size"; + // Enable the DQ->MatMulNBits fusion graph transformer. // "0": disabled (default). "1": enabled. // This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered. diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 2765069ebf336..4f107ae72c0e9 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -18,23 +18,50 @@ namespace adapter { /// class Allocator : public OrtAllocator { public: + /** + * Create from an existing AllocatorPtr. 
+ */ explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl) - : OrtAllocator{}, memory_info_(memory_info), impl_(impl) { + : Allocator{memory_info} { + ORT_ENFORCE(impl != nullptr, "Allocator implementation cannot be null."); + impl_ = impl; + } + + using AllocatorFactory = AllocatorPtr (*)(const OrtMemoryInfo& memory_info); + + /** + * Create from an AllocatorFactory, which will be called lazily when the first allocation is made. + */ + explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorFactory get_allocator_impl) + : Allocator{memory_info} { + get_allocator_impl_ = get_allocator_impl; + } + + private: + explicit Allocator(const OrtMemoryInfo* memory_info) + : OrtAllocator{}, memory_info_(memory_info) { version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; } + AllocatorPtr GetImpl() { + if (!impl_) { + std::call_once(init_flag_, [this]() { + impl_ = get_allocator_impl_(*memory_info_); + }); + } + return impl_; + } - private: static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept { auto* allocator = static_cast(this_ptr); - return allocator->impl_->Alloc(size); + return allocator->GetImpl()->Alloc(size); } static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept { auto* allocator = static_cast(this_ptr); - allocator->impl_->Free(p); + allocator->GetImpl()->Free(p); } static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept { @@ -44,6 +71,8 @@ class Allocator : public OrtAllocator { const OrtMemoryInfo* memory_info_; AllocatorPtr impl_; + AllocatorFactory get_allocator_impl_; + std::once_flag init_flag_; }; } // namespace adapter diff --git a/include/onnxruntime/ep/adapter/ep.h b/include/onnxruntime/ep/adapter/ep.h index 34fc7682a8138..ca0a8c9599eda 100644 --- a/include/onnxruntime/ep/adapter/ep.h +++ b/include/onnxruntime/ep/adapter/ep.h @@ -27,6 +27,7 @@ class Ep : public OrtEp { profiler_{impl_->GetProfiler()}, 
temp_space_cpu_allocator_{temp_space_cpu_allocator}, temp_space_allocator_{temp_space_allocator} { + ort_version_supported = ORT_API_VERSION; } public: diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h index bd6172a668e33..644cb30788ec6 100644 --- a/include/onnxruntime/ep/adapter/op_kernel_info.h +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -19,7 +19,7 @@ #include "tensor_helper.h" namespace onnxruntime { -struct DataTransferManager; +class DataTransferManager; struct IExecutionProvider; } // namespace onnxruntime diff --git a/include/onnxruntime/ep/common.h b/include/onnxruntime/ep/common.h index 03cd571461755..0e779ba3d4081 100644 --- a/include/onnxruntime/ep/common.h +++ b/include/onnxruntime/ep/common.h @@ -41,3 +41,24 @@ OrtStatus* _status = (status_expr); \ Ort::Status _ignored{_status}; \ } while (false) + +// Helper macros to convert exceptions to OrtStatus* return values. +// Usage: +// EXCEPTION_TO_RETURNED_STATUS_BEGIN +// ... code that may throw ... +// EXCEPTION_TO_RETURNED_STATUS_END +#define EXCEPTION_TO_RETURNED_STATUS_BEGIN try { +#define EXCEPTION_TO_RETURNED_STATUS_END \ + } \ + catch (const Ort::Exception& ex) { \ + Ort::Status status(ex); \ + return status.release(); \ + } \ + catch (const std::exception& ex) { \ + Ort::Status status(ex.what(), ORT_EP_FAIL); \ + return status.release(); \ + } \ + catch (...) 
{ \ + Ort::Status status("Unknown exception", ORT_EP_FAIL); \ + return status.release(); \ + } diff --git a/js/package-lock.json b/js/package-lock.json index 22fb22757e94b..1ba8fc900bbd8 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,7 +4,6 @@ "requires": true, "packages": { "": { - "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3014,11 +3013,10 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", - "dev": true, - "license": "ISC" + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", + "dev": true }, "node_modules/for-each": { "version": "0.3.5", @@ -7946,9 +7944,9 @@ } }, "flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true }, "for-each": { diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index d8a273ef6825f..3f9ff05a72f97 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -7149,9 +7149,9 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": 
"https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true, "license": "ISC" }, diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e486e4b0e043d..eeff82484ef5f 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -278,7 +278,7 @@ export class WebGpuBackend { value: this.device, writable: false, enumerable: true, - configurable: false, + configurable: true, // Allow deletion when device is destroyed }); Object.defineProperty(this.env.webgpu, 'adapter', { value: adapter, @@ -296,6 +296,14 @@ export class WebGpuBackend { this.querySet.destroy(); } this.gpuDataManager.dispose(); + + // Clear the device reference when it's lost to allow new sessions to create a fresh device + // This handles the case where preserve_device=false (default) causes the C++ side to destroy the device + if (this.device && this.env?.webgpu) { + void this.device.lost.then(() => { + delete (this.env.webgpu as unknown as Record).device; + }); + } } getCommandEncoder(): GPUCommandEncoder { diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index 62b4df5806eda..ed0559c85ee1b 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -493,9 +493,9 @@ "license": "MIT" }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.35.0.tgz", - "integrity": "sha512-uYQ2WfPaqz5QtVgMxfN6NpLD+no0MYHDBywl7itPYd3K5TjjSghNKmX8ic9S8NU8w81NVhJv/XojcHptRly7qQ==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.59.0.tgz", + 
"integrity": "sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==", "cpu": [ "arm" ], @@ -507,9 +507,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.35.0.tgz", - "integrity": "sha512-FtKddj9XZudurLhdJnBl9fl6BwCJ3ky8riCXjEw3/UIbjmIY58ppWwPEvU3fNu+W7FUsAsB1CdH+7EQE6CXAPA==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.59.0.tgz", + "integrity": "sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==", "cpu": [ "arm64" ], @@ -521,9 +521,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.35.0.tgz", - "integrity": "sha512-Uk+GjOJR6CY844/q6r5DR/6lkPFOw0hjfOIzVx22THJXMxktXG6CbejseJFznU8vHcEBLpiXKY3/6xc+cBm65Q==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.59.0.tgz", + "integrity": "sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==", "cpu": [ "arm64" ], @@ -535,9 +535,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.35.0.tgz", - "integrity": "sha512-3IrHjfAS6Vkp+5bISNQnPogRAW5GAV1n+bNCrDwXmfMHbPl5EhTmWtfmwlJxFRUCBZ+tZ/OxDyU08aF6NI/N5Q==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.59.0.tgz", + "integrity": "sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==", "cpu": [ "x64" ], @@ -549,9 +549,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.35.0", - "resolved": 
"https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.35.0.tgz", - "integrity": "sha512-sxjoD/6F9cDLSELuLNnY0fOrM9WA0KrM0vWm57XhrIMf5FGiN8D0l7fn+bpUeBSU7dCgPV2oX4zHAsAXyHFGcQ==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.59.0.tgz", + "integrity": "sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==", "cpu": [ "arm64" ], @@ -563,9 +563,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.35.0.tgz", - "integrity": "sha512-2mpHCeRuD1u/2kruUiHSsnjWtHjqVbzhBkNVQ1aVD63CcexKVcQGwJ2g5VphOd84GvxfSvnnlEyBtQCE5hxVVw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.59.0.tgz", + "integrity": "sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==", "cpu": [ "x64" ], @@ -577,9 +577,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.35.0.tgz", - "integrity": "sha512-mrA0v3QMy6ZSvEuLs0dMxcO2LnaCONs1Z73GUDBHWbY8tFFocM6yl7YyMu7rz4zS81NDSqhrUuolyZXGi8TEqg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.59.0.tgz", + "integrity": "sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==", "cpu": [ "arm" ], @@ -591,9 +591,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.35.0.tgz", - "integrity": "sha512-DnYhhzcvTAKNexIql8pFajr0PiDGrIsBYPRvCKlA5ixSS3uwo/CWNZxB09jhIapEIg945KOzcYEAGGSmTSpk7A==", + "version": 
"4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.59.0.tgz", + "integrity": "sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==", "cpu": [ "arm" ], @@ -605,9 +605,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.35.0.tgz", - "integrity": "sha512-uagpnH2M2g2b5iLsCTZ35CL1FgyuzzJQ8L9VtlJ+FckBXroTwNOaD0z0/UF+k5K3aNQjbm8LIVpxykUOQt1m/A==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.59.0.tgz", + "integrity": "sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==", "cpu": [ "arm64" ], @@ -619,9 +619,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.35.0.tgz", - "integrity": "sha512-XQxVOCd6VJeHQA/7YcqyV0/88N6ysSVzRjJ9I9UA/xXpEsjvAgDTgH3wQYz5bmr7SPtVK2TsP2fQ2N9L4ukoUg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.59.0.tgz", + "integrity": "sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==", "cpu": [ "arm64" ], @@ -632,10 +632,10 @@ "linux" ] }, - "node_modules/@rollup/rollup-linux-loongarch64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loongarch64-gnu/-/rollup-linux-loongarch64-gnu-4.35.0.tgz", - "integrity": "sha512-5pMT5PzfgwcXEwOaSrqVsz/LvjDZt+vQ8RT/70yhPU06PTuq8WaHhfT1LW+cdD7mW6i/J5/XIkX/1tCAkh1W6g==", + "node_modules/@rollup/rollup-linux-loong64-gnu": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.59.0.tgz", + 
"integrity": "sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==", "cpu": [ "loong64" ], @@ -646,10 +646,38 @@ "linux" ] }, - "node_modules/@rollup/rollup-linux-powerpc64le-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-powerpc64le-gnu/-/rollup-linux-powerpc64le-gnu-4.35.0.tgz", - "integrity": "sha512-c+zkcvbhbXF98f4CtEIP1EBA/lCic5xB0lToneZYvMeKu5Kamq3O8gqrxiYYLzlZH6E3Aq+TSW86E4ay8iD8EA==", + "node_modules/@rollup/rollup-linux-loong64-musl": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.59.0.tgz", + "integrity": "sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-gnu": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.59.0.tgz", + "integrity": "sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-musl": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.59.0.tgz", + "integrity": "sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==", "cpu": [ "ppc64" ], @@ -661,9 +689,23 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.35.0.tgz", - "integrity": "sha512-s91fuAHdOwH/Tad2tzTtPX7UZyytHIRR6V4+2IGlV0Cej5rkG0R61SX4l4y9sh0JBibMiploZx3oHKPnQBKe4g==", + "version": "4.59.0", + 
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.59.0.tgz", + "integrity": "sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-musl": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.59.0.tgz", + "integrity": "sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==", "cpu": [ "riscv64" ], @@ -675,9 +717,9 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.35.0.tgz", - "integrity": "sha512-hQRkPQPLYJZYGP+Hj4fR9dDBMIM7zrzJDWFEMPdTnTy95Ljnv0/4w/ixFw3pTBMEuuEuoqtBINYND4M7ujcuQw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.59.0.tgz", + "integrity": "sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==", "cpu": [ "s390x" ], @@ -689,9 +731,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.35.0.tgz", - "integrity": "sha512-Pim1T8rXOri+0HmV4CdKSGrqcBWX0d1HoPnQ0uw0bdp1aP5SdQVNBy8LjYncvnLgu3fnnCt17xjWGd4cqh8/hA==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.59.0.tgz", + "integrity": "sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==", "cpu": [ "x64" ], @@ -703,9 +745,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.35.0", - "resolved": 
"https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.35.0.tgz", - "integrity": "sha512-QysqXzYiDvQWfUiTm8XmJNO2zm9yC9P/2Gkrwg2dH9cxotQzunBHYr6jk4SujCTqnfGxduOmQcI7c2ryuW8XVg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.59.0.tgz", + "integrity": "sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==", "cpu": [ "x64" ], @@ -716,10 +758,38 @@ "linux" ] }, + "node_modules/@rollup/rollup-openbsd-x64": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.59.0.tgz", + "integrity": "sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ] + }, + "node_modules/@rollup/rollup-openharmony-arm64": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.59.0.tgz", + "integrity": "sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ] + }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.35.0.tgz", - "integrity": "sha512-OUOlGqPkVJCdJETKOCEf1mw848ZyJ5w50/rZ/3IBQVdLfR5jk/6Sr5m3iO2tdPgwo0x7VcncYuOvMhBWZq8ayg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.59.0.tgz", + "integrity": "sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==", "cpu": [ "arm64" ], @@ -731,9 +801,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.35.0", - 
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.35.0.tgz", - "integrity": "sha512-2/lsgejMrtwQe44glq7AFFHLfJBPafpsTa6JvP2NGef/ifOa4KBoglVf7AKN7EV9o32evBPRqfg96fEHzWo5kw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.59.0.tgz", + "integrity": "sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==", "cpu": [ "ia32" ], @@ -744,10 +814,24 @@ "win32" ] }, + "node_modules/@rollup/rollup-win32-x64-gnu": { + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.59.0.tgz", + "integrity": "sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.35.0.tgz", - "integrity": "sha512-PIQeY5XDkrOysbQblSW7v3l1MDZzkTEzAfTPkj5VAu3FW8fS4ynyLg2sINp0fp3SjZ8xkRYpLqoKcYqAkhU1dw==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.59.0.tgz", + "integrity": "sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==", "cpu": [ "x64" ], @@ -759,9 +843,9 @@ ] }, "node_modules/@types/estree": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.6.tgz", - "integrity": "sha512-AYnb1nQyY49te+VRAVgmzfcgjYS91mY5P0TKUDCLEM+gNnA+3T6rWITXRLYCpahpqSQbN5cE+gHpnPyXjHWxcw==", + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "dev": true, "license": "MIT" }, 
@@ -1049,13 +1133,13 @@ } }, "node_modules/rollup": { - "version": "4.35.0", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.35.0.tgz", - "integrity": "sha512-kg6oI4g+vc41vePJyO6dHt/yl0Rz3Thv0kJeVQ3D1kS3E5XSuKbPc29G4IpT/Kv1KQwgHVcN+HtyS+HYLNSvQg==", + "version": "4.59.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.59.0.tgz", + "integrity": "sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==", "dev": true, "license": "MIT", "dependencies": { - "@types/estree": "1.0.6" + "@types/estree": "1.0.8" }, "bin": { "rollup": "dist/bin/rollup" @@ -1065,25 +1149,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.35.0", - "@rollup/rollup-android-arm64": "4.35.0", - "@rollup/rollup-darwin-arm64": "4.35.0", - "@rollup/rollup-darwin-x64": "4.35.0", - "@rollup/rollup-freebsd-arm64": "4.35.0", - "@rollup/rollup-freebsd-x64": "4.35.0", - "@rollup/rollup-linux-arm-gnueabihf": "4.35.0", - "@rollup/rollup-linux-arm-musleabihf": "4.35.0", - "@rollup/rollup-linux-arm64-gnu": "4.35.0", - "@rollup/rollup-linux-arm64-musl": "4.35.0", - "@rollup/rollup-linux-loongarch64-gnu": "4.35.0", - "@rollup/rollup-linux-powerpc64le-gnu": "4.35.0", - "@rollup/rollup-linux-riscv64-gnu": "4.35.0", - "@rollup/rollup-linux-s390x-gnu": "4.35.0", - "@rollup/rollup-linux-x64-gnu": "4.35.0", - "@rollup/rollup-linux-x64-musl": "4.35.0", - "@rollup/rollup-win32-arm64-msvc": "4.35.0", - "@rollup/rollup-win32-ia32-msvc": "4.35.0", - "@rollup/rollup-win32-x64-msvc": "4.35.0", + "@rollup/rollup-android-arm-eabi": "4.59.0", + "@rollup/rollup-android-arm64": "4.59.0", + "@rollup/rollup-darwin-arm64": "4.59.0", + "@rollup/rollup-darwin-x64": "4.59.0", + "@rollup/rollup-freebsd-arm64": "4.59.0", + "@rollup/rollup-freebsd-x64": "4.59.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.59.0", + "@rollup/rollup-linux-arm-musleabihf": "4.59.0", + "@rollup/rollup-linux-arm64-gnu": "4.59.0", + 
"@rollup/rollup-linux-arm64-musl": "4.59.0", + "@rollup/rollup-linux-loong64-gnu": "4.59.0", + "@rollup/rollup-linux-loong64-musl": "4.59.0", + "@rollup/rollup-linux-ppc64-gnu": "4.59.0", + "@rollup/rollup-linux-ppc64-musl": "4.59.0", + "@rollup/rollup-linux-riscv64-gnu": "4.59.0", + "@rollup/rollup-linux-riscv64-musl": "4.59.0", + "@rollup/rollup-linux-s390x-gnu": "4.59.0", + "@rollup/rollup-linux-x64-gnu": "4.59.0", + "@rollup/rollup-linux-x64-musl": "4.59.0", + "@rollup/rollup-openbsd-x64": "4.59.0", + "@rollup/rollup-openharmony-arm64": "4.59.0", + "@rollup/rollup-win32-arm64-msvc": "4.59.0", + "@rollup/rollup-win32-ia32-msvc": "4.59.0", + "@rollup/rollup-win32-x64-gnu": "4.59.0", + "@rollup/rollup-win32-x64-msvc": "4.59.0", "fsevents": "~2.3.2" } }, diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index f00fad809968f..71e0e8561e110 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -78,22 +78,22 @@ class QuickGelu : public OpKernel { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - if (alpha_ != 1.0f) { - // TODO: Consider vectorizing this scalar multiplication. - // It needs exposing a new API in MLAS to take in a scalar - // that will be used in the elementwise multiplication. - // Estimate the cost-benefit tradeoff before proceeding - // with that optimization. - for (int64_t i = 0; i < count; i++) { - p_output[i] = p_input[i] * alpha_; - } - - MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); - } else { - // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f - MlasComputeLogistic(p_input, p_output, onnxruntime::narrow(count)); + if (alpha_ == 1.0f) { + MlasComputeSilu(p_input, p_output, onnxruntime::narrow(count)); + return; } + // TODO: Consider vectorizing this scalar multiplication. 
+    // It needs exposing a new API in MLAS to take in a scalar
+    // that will be used in the elementwise multiplication.
+    // Estimate the cost-benefit tradeoff before proceeding
+    // with that optimization.
+    for (int64_t i = 0; i < count; i++) {
+      p_output[i] = p_input[i] * alpha_;
+    }
+
+    MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
+    MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow<size_t>(count));
   },
   0);
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 56849995656f3..2b446c4b2601b 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1113,6 +1113,30 @@ MlasComputeErf(
     size_t N
     );
 
+//
+// Note: The Input and Output buffers for MlasComputeGeluErf must not overlap.
+// In-place operation (e.g., passing the same buffer for both parameters) is unsupported.
+//
+void
+MLASCALL
+MlasComputeGeluErf(
+    const float* Input,
+    float* Output,
+    size_t N
+    );
+
+//
+// Note: The Input and Output buffers for MlasComputeSilu must not overlap.
+// In-place operation (e.g., passing the same buffer for both parameters) is unsupported.
+//
+void
+MLASCALL
+MlasComputeSilu(
+    const float* Input,
+    float* Output,
+    size_t N
+    );
+
 template <typename T>
 void
 MLASCALL
diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp
new file mode 100644
index 0000000000000..dc25611652c77
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/gelu.cpp
@@ -0,0 +1,65 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    gelu.cpp
+
+Abstract:
+
+    This module implements routines to compute the exact Gelu function.
+ +--*/ + +#include "mlasi.h" + +namespace { + +constexpr float kInvSqrt2 = 0.70710678118654752440f; + +} // namespace + + +void +MLASCALL +MlasGeluErfKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe because it is implemented in + // multiple passes: first scale Input into Output, then apply erf in place, + // and finally combine that intermediate with the original Input values. + // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). + for (size_t i = 0; i < N; ++i) { + Output[i] = Input[i] * kInvSqrt2; + } + + MlasComputeErf(Output, Output, N); + + for (size_t i = 0; i < N; ++i) { + Output[i] = 0.5f * Input[i] * (Output[i] + 1.0f); + } +} + +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 GELU(erf) path on AMD64. + // Today the dispatch jumps from the generic multi-pass implementation to + // AVX512F, so non-AVX512 x64 machines fall back to the generic kernel. + GetMlasPlatform().GeluErfKernelRoutine(Input, Output, N); +#else + MlasGeluErfKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp new file mode 100644 index 0000000000000..4a9f3a100ed65 --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -0,0 +1,219 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu_avx512f.cpp + +Abstract: + + This module implements routines to compute exact Gelu with AVX512F + intrinsics. 
+
+--*/
+
+#include <immintrin.h>
+
+#include "mlasi.h"
+
+namespace {
+
+struct GeluAvx512Constants {
+    static constexpr int32_t SignBitMask = INT32_MIN;
+    static constexpr float InvSqrt2 = 0.70710678118654752440f;
+    static constexpr float Half = 0.5f;
+    static constexpr float One = 1.0f;
+
+    static constexpr float ErfUpperAbsRange = 3.925f;
+    static constexpr float ErfSplitBoundary = 0.921875f;
+    static constexpr float ErfSMALL_P0 = -5.99104969e-4f;
+    static constexpr float ErfSMALL_P1 = 4.99339588e-3f;
+    static constexpr float ErfSMALL_P2 = -2.67667342e-2f;
+    static constexpr float ErfSMALL_P3 = 1.12818025e-1f;
+    static constexpr float ErfSMALL_P4 = -3.76124859e-1f;
+    static constexpr float ErfSMALL_P5_Minus_One = 1.28379151e-1f;
+    static constexpr float ErfBIG_P0 = 1.72948930e-5f;
+    static constexpr float ErfBIG_P1 = -3.83208680e-4f;
+    static constexpr float ErfBIG_P2 = 3.88393435e-3f;
+    static constexpr float ErfBIG_P3 = -2.42545605e-2f;
+    static constexpr float ErfBIG_P4 = 1.06777847e-1f;
+    static constexpr float ErfBIG_P5 = 6.34846687e-1f;
+    static constexpr float ErfBIG_P6_Minus_One = 1.28717512e-1f;
+    static constexpr float ErfOne = 1.0f;
+    static constexpr float ExpLowerRange = -88.3762626647949f;
+    static constexpr float ExpLog2Reciprocal = 1.44269504088896341f;
+    static constexpr float ExpLog2Hi = -6.93145752e-1f;
+    static constexpr float ExpLog2Lo = -1.42860677e-6f;
+    static constexpr float ExpP0 = 1.38319808e-3f;
+    static constexpr float ExpP1 = 8.37550033e-3f;
+    static constexpr float ExpP2 = 4.16689515e-2f;
+    static constexpr float ExpP3 = 1.66664466e-1f;
+    static constexpr float ExpP4 = 4.99999851e-1f;
+    static constexpr float ExpP5 = 1.0f;
+    static constexpr float ExpP6 = 1.0f;
+    static constexpr float ExpC = 1.25829120e+7f;
+};
+
+struct GeluAvx512BroadcastConstants {
+    const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask));
+    const __m512 Zero = _mm512_setzero_ps();
+    const __m512 InvSqrt2 =
_mm512_set1_ps(GeluAvx512Constants::InvSqrt2); + const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half); + const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One); + const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); + const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); + const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0); + const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1); + const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2); + const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3); + const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4); + const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One); + const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0); + const __m512 ErfBigP1 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1); + const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2); + const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3); + const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4); + const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5); + const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One); + const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); + const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); + const __m512 ExpLog2Reciprocal = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Reciprocal); + const __m512 ExpLog2Hi = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Hi); + const __m512 ExpLog2Lo = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Lo); + const __m512 ExpP0 = _mm512_set1_ps(GeluAvx512Constants::ExpP0); + const __m512 ExpP1 = _mm512_set1_ps(GeluAvx512Constants::ExpP1); + const __m512 ExpP2 = _mm512_set1_ps(GeluAvx512Constants::ExpP2); + const __m512 ExpP3 = 
_mm512_set1_ps(GeluAvx512Constants::ExpP3); + const __m512 ExpP4 = _mm512_set1_ps(GeluAvx512Constants::ExpP4); + const __m512 ExpP5 = _mm512_set1_ps(GeluAvx512Constants::ExpP5); + const __m512 ExpP6 = _mm512_set1_ps(GeluAvx512Constants::ExpP6); + const __m512 ExpC = _mm512_set1_ps(GeluAvx512Constants::ExpC); +}; + +MLAS_FORCEINLINE __m512 +MlasGeluErfExpVectorAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + __m512 R = _mm512_fmadd_ps(Constants.ExpLog2Reciprocal, Value, Constants.ExpC); + R = _mm512_sub_ps(R, Constants.ExpC); + + __m512 Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Hi, Value); + Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Lo, Fx); + + __m512 Y = Constants.ExpP0; + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP1); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP2); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP3); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP4); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP5); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP6); + Y = _mm512_scalef_ps(Y, R); + + return Y; +} + +MLAS_FORCEINLINE __m512 +MlasGeluErfAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(Constants.NegZero))); + __m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(Constants.NegZero), _mm512_castps_si512(Value))); + AbsValue = _mm512_min_ps(Constants.ErfUpperAbsRange, AbsValue); + + const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue); + + __m512 SmallResult = Constants.ErfSmallP0; + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP1); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP2); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP3); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP4); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, 
Constants.ErfSmallP5MinusOne); + SmallResult = _mm512_fmadd_ps(SmallResult, AbsValue, AbsValue); + + const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, Constants.ErfSplitBoundary, _CMP_GT_OQ); + const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Constants.Zero, AbsValue); + + __m512 BigResult = Constants.ErfBigP0; + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP1); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP2); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP3); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP4); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP5); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP6MinusOne); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput); + + BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(Constants.NegZero))); + BigResult = _mm512_max_ps(Constants.ExpLowerRange, BigResult); + BigResult = _mm512_sub_ps(Constants.ErfOne, MlasGeluErfExpVectorAvx512(BigResult, Constants)); + + __m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult); + Result = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(Result), _mm512_castps_si512(SignMask))); + return Result; +} + +MLAS_FORCEINLINE __m512 +MlasComputeGeluVectorExactAvx512( + __m512 X, + const GeluAvx512BroadcastConstants& Constants + ) +{ + const __m512 ErfInput = _mm512_mul_ps(X, Constants.InvSqrt2); + const __m512 ErfValue = MlasGeluErfAvx512(ErfInput, Constants); + __m512 Result = _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One)); + + // Preserve NaN payload/sign behavior explicitly because the erf + // approximation uses min/max style range limiting that is not guaranteed to + // preserve NaNs the same way as the existing MLAS GELU semantics. 
+    const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q);
+    Result = _mm512_mask_mov_ps(Result, NaNMask, X);
+
+    return Result;
+}
+
+void
+MlasGeluErfKernelAvx512FExactImpl(
+    const float* Input,
+    float* Output,
+    size_t N
+    )
+{
+    const GeluAvx512BroadcastConstants Constants;
+    while (N >= 16) {
+        const __m512 X = _mm512_loadu_ps(Input);
+        const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants);
+
+        _mm512_storeu_ps(Output, Result);
+
+        Input += 16;
+        Output += 16;
+        N -= 16;
+    }
+
+    if (N > 0) {
+        const __mmask16 TailMask = __mmask16((1u << static_cast<unsigned>(N)) - 1u);
+        const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input);
+        const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants);
+
+        _mm512_mask_storeu_ps(Output, TailMask, Result);
+    }
+}
+
+} // namespace
+
+void
+MLASCALL
+MlasGeluErfKernelAvx512F(
+    const float* Input,
+    float* Output,
+    size_t N
+    )
+{
+    MlasGeluErfKernelAvx512FExactImpl(Input, Output, N);
+}
diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp
new file mode 100644
index 0000000000000..7e8424d94827a
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp
@@ -0,0 +1,140 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    silu_avx512f.cpp
+
+Abstract:
+
+    This module implements routines to compute the SiLU function with AVX512F
+    intrinsics.
+ +--*/ + +#include "mlasi.h" + +namespace { + +struct SiluAvx512Constants { + static constexpr float LogisticLowerRange = -18.0f; + static constexpr float LogisticUpperRange = 18.0f; + static constexpr float Alpha9 = 4.37031012579801e-11f; + static constexpr float Alpha7 = 1.15627324459942e-07f; + static constexpr float Alpha5 = 6.08574864600143e-05f; + static constexpr float Alpha3 = 8.51377133304701e-03f; + static constexpr float Alpha1 = 2.48287947061529e-01f; + static constexpr float Beta10 = 6.10247389755681e-13f; + static constexpr float Beta8 = 5.76102136993427e-09f; + static constexpr float Beta6 = 6.29106785017040e-06f; + static constexpr float Beta4 = 1.70198817374094e-03f; + static constexpr float Beta2 = 1.16817656904453e-01f; + static constexpr float Beta0 = 9.93151921023180e-01f; + static constexpr float OneHalf = 0.5f; +}; + +struct SiluAvx512BroadcastConstants { + const __m512 LogisticLowerRange = _mm512_set1_ps(SiluAvx512Constants::LogisticLowerRange); + const __m512 LogisticUpperRange = _mm512_set1_ps(SiluAvx512Constants::LogisticUpperRange); + const __m512 Alpha9 = _mm512_set1_ps(SiluAvx512Constants::Alpha9); + const __m512 Alpha7 = _mm512_set1_ps(SiluAvx512Constants::Alpha7); + const __m512 Alpha5 = _mm512_set1_ps(SiluAvx512Constants::Alpha5); + const __m512 Alpha3 = _mm512_set1_ps(SiluAvx512Constants::Alpha3); + const __m512 Alpha1 = _mm512_set1_ps(SiluAvx512Constants::Alpha1); + const __m512 Beta10 = _mm512_set1_ps(SiluAvx512Constants::Beta10); + const __m512 Beta8 = _mm512_set1_ps(SiluAvx512Constants::Beta8); + const __m512 Beta6 = _mm512_set1_ps(SiluAvx512Constants::Beta6); + const __m512 Beta4 = _mm512_set1_ps(SiluAvx512Constants::Beta4); + const __m512 Beta2 = _mm512_set1_ps(SiluAvx512Constants::Beta2); + const __m512 Beta0 = _mm512_set1_ps(SiluAvx512Constants::Beta0); + const __m512 OneHalf = _mm512_set1_ps(SiluAvx512Constants::OneHalf); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 One = _mm512_set1_ps(1.0f); +}; + 
+MLAS_FORCEINLINE __m512 +MlasLogisticApproxAvx512( + __m512 Value, + const SiluAvx512BroadcastConstants& Constants + ) +{ + // Mirror MlasComputeLogistic by evaluating the same clamped rational + // approximation in-register and then multiplying by x for SiLU. + const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, Constants.LogisticUpperRange), Constants.LogisticLowerRange); + const __m512 ValueSquared = _mm512_mul_ps(ClampedValue, ClampedValue); + + __m512 P = _mm512_fmadd_ps(ValueSquared, Constants.Alpha9, Constants.Alpha7); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha5); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha3); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha1); + P = _mm512_mul_ps(P, ClampedValue); + + __m512 Q = _mm512_fmadd_ps(ValueSquared, Constants.Beta10, Constants.Beta8); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta6); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta4); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta2); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta0); + + __m512 Logistic = _mm512_add_ps(_mm512_div_ps(P, Q), Constants.OneHalf); + Logistic = _mm512_min_ps(_mm512_max_ps(Logistic, Constants.Zero), Constants.One); + + return Logistic; +} + +MLAS_FORCEINLINE __m512 +MlasComputeSiluVectorAvx512( + __m512 X, + const SiluAvx512BroadcastConstants& Constants + ) +{ + __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X, Constants)); + + // Preserve NaN payload/sign behavior explicitly because the clamped + // logistic approximation uses min/max operations that do not reliably + // propagate NaNs the same way as the existing MLAS SiLU semantics. 
+ const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + return Result; +} + +} // namespace + +void +MLASCALL +MlasSiluKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + const SiluAvx512BroadcastConstants Constants; + size_t Offset = 0; + + while (Offset + 32 <= N) { + const __m512 X0 = _mm512_loadu_ps(Input + Offset); + const __m512 X1 = _mm512_loadu_ps(Input + Offset + 16); + const __m512 Result0 = MlasComputeSiluVectorAvx512(X0, Constants); + const __m512 Result1 = MlasComputeSiluVectorAvx512(X1, Constants); + _mm512_storeu_ps(Output + Offset, Result0); + _mm512_storeu_ps(Output + Offset + 16, Result1); + Offset += 32; + } + + while (Offset + 16 <= N) { + const __m512 X = _mm512_loadu_ps(Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); + _mm512_storeu_ps(Output + Offset, Result); + Offset += 16; + } + + if (Offset < N) { + const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); + _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 954849fe90049..0dab8e41f25cd 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1096,6 +1096,8 @@ extern "C" { #endif MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasTanhKernel; @@ -1126,6 +1128,8 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL 
MlasQuantizeLinearU8KernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelAvx512F; #endif MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; @@ -1477,6 +1481,8 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif #if defined(MLAS_TARGET_AMD64) + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ac3761d63bd20..eccde79848e61 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -283,7 +283,9 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; + this->GeluErfKernelRoutine = MlasGeluErfKernel; this->LogisticKernelRoutine = MlasLogisticKernel; + this->SiluKernelRoutine = MlasSiluKernel; this->TanhKernelRoutine = MlasTanhKernel; this->ErfKernelRoutine = MlasErfKernel; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; @@ -459,7 +461,8 @@ Return Value: // if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { - + this->GeluErfKernelRoutine = MlasGeluErfKernelAvx512F; + this->SiluKernelRoutine = MlasSiluKernelAvx512F; this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F; diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp new file mode 100644 index 0000000000000..96686e4bdf1da --- /dev/null +++ 
b/onnxruntime/core/mlas/lib/silu.cpp @@ -0,0 +1,51 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + silu.cpp + +Abstract: + + This module implements routines to compute the SiLU function. + +--*/ + +#include "mlasi.h" + +void +MLASCALL +MlasSiluKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe because it is implemented in two + // passes: first compute logistic(Input) into Output, then multiply that + // intermediate by the original Input values. Callers must guarantee that + // Input and Output do not overlap (see mlas.h for aliasing requirements). + MlasComputeLogistic(Input, Output, N); + MlasEltwiseMul(Input, Output, Output, N); +} + +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 SiLU path on AMD64. Today the + // dispatch jumps from the generic two-pass implementation to AVX512F, so + // non-AVX512 x64 machines fall back to the generic kernel. 
+    GetMlasPlatform().SiluKernelRoutine(Input, Output, N);
+#else
+    MlasSiluKernel(Input, Output, N);
+#endif
+}
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index ac712084012a4..9ed1d5e9e84fa 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -355,6 +355,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
         ParseStringWithClassicLocale<int64_t>(
             session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
                                                               "4"));
+    const int64_t qdq_matmulnbits_block_size =
+        ParseStringWithClassicLocale<int64_t>(
+            session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize,
+                                                              "0"));
 #ifdef MLAS_TARGET_AMD64_IX86
     const bool avx2_precision_mode =
         session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -371,7 +375,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
                                                                                SatApplyContextVariant{},
                                                                                qdq_matmulnbits_accuracy_level,
-                                                                               intra_op_thread_pool));
+                                                                               intra_op_thread_pool,
+                                                                               qdq_matmulnbits_block_size));
     }
 
     transformers.emplace_back(std::make_unique(cpu_ep));
@@ -513,6 +518,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
         ParseStringWithClassicLocale<int64_t>(
             session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
                                                               "4"));
+    const int64_t qdq_matmulnbits_block_size =
+        ParseStringWithClassicLocale<int64_t>(
+            session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize,
+                                                              "0"));
 
     // runtime optimizations only support CPU EP now
     const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
@@ -520,7 +529,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
       transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
                                                                                apply_context,
                                                                                qdq_matmulnbits_accuracy_level,
-
intra_op_thread_pool,
+                                                                               qdq_matmulnbits_block_size));
     }
 
     transformers.emplace_back(std::make_unique(cpu_ep, apply_context));
diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc
index 4e03450077718..b9366ff0abae8 100644
--- a/onnxruntime/core/optimizer/nchwc_transformer.cc
+++ b/onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -881,12 +881,21 @@ void NchwcTransformerImpl::TransformActivation(Node& node) {
   const bool can_fuse_activation =
       (node.OpType() == "Relu") ||
       (node.OpType() == "Sigmoid") ||
-      (node.OpType() == "Tanh");
+      (node.OpType() == "Tanh") ||
+      (node.OpType() == "HardSigmoid");
 
   if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
       can_fuse_activation && (nchwc_input->starting_original_uses_ == 1) &&
       (graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
     nchwc_node.AddAttribute("activation", node.OpType());
+    if (node.OpType() == "HardSigmoid") {
+      const auto* alpha_attr = graph_utils::GetNodeAttribute(node, "alpha");
+      const auto* beta_attr = graph_utils::GetNodeAttribute(node, "beta");
+      InlinedVector<float> activation_params{
+          alpha_attr == nullptr ? 0.2f : alpha_attr->f(),
+          beta_attr == nullptr ?
0.5f : beta_attr->f()}; + nchwc_node.AddAttribute("activation_params", activation_params); + } FuseNchwcArgument(node, *nchwc_input); removed_nodes_.push_front(node.Index()); } else { @@ -1265,8 +1274,10 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6, 22}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {20}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) { TransformActivation(node); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index fdc0818e8437b..b9d7e898157bd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -42,6 +42,67 @@ bool IsDQWeightSigned(int32_t dt_weight) { dt_weight == TensorProto::INT8; } +// Compute the effective block_size for per-tensor/per-channel DQ nodes that lack a block_size attribute. +// session_block_size: 0 = default (32), positive = explicit, -1 = min-padding heuristic. +int64_t ComputeEffectiveBlockSize(int64_t session_block_size, int64_t K) { + // MatMulNBits CPU kernel currently only supports block_size in [16, 256] correctly. + constexpr int64_t kMinBlockSize = 16; + constexpr int64_t kMaxBlockSize = 256; + + if (session_block_size > 0) { + // Explicit block_size — must be power-of-2 and within [kMinBlockSize, kMaxBlockSize]. 
+ ORT_ENFORCE(session_block_size >= kMinBlockSize && + ((session_block_size & (session_block_size - 1)) == 0), + "Explicit qdq_matmulnbits_block_size must be a power-of-2 and >= ", + kMinBlockSize, ", got: ", session_block_size); + ORT_ENFORCE(session_block_size <= kMaxBlockSize, + "Explicit qdq_matmulnbits_block_size must be <= ", + kMaxBlockSize, ", got: ", session_block_size); + return session_block_size; + } + + if (session_block_size == -1) { + // Heuristic: largest power-of-2 <= min(K, kMaxBlockSize) that minimizes padding. + // Capped at kMaxBlockSize because CPU EP only supports block_size up to kMaxBlockSize correctly. + // We want ceil(K / B) * B - K to be minimized (least wasted padding). + int64_t best_bs = kMinBlockSize; + int64_t best_padding = (((K + (kMinBlockSize - 1)) / kMinBlockSize) * kMinBlockSize) - K; + for (int64_t bs = kMinBlockSize * 2; bs <= std::min(K, kMaxBlockSize); bs *= 2) { + int64_t padding = (((K + bs - 1) / bs) * bs) - K; + if (padding <= best_padding) { + best_padding = padding; + best_bs = bs; + } + } + return best_bs; + } + + // Default (session_block_size == 0): use 32 + return 32; +} + +// Get the DQ block_size: from the attribute if blockwise, or computed for per-tensor/per-channel. +int64_t GetEffectiveBlockSize(const Node& dq_node, int64_t block_size_for_non_blockwise) { + const auto& dq_attrs = dq_node.GetAttributes(); + const auto bs_iter = dq_attrs.find("block_size"); + if (bs_iter != dq_attrs.end() && bs_iter->second.i() > 0) { + return bs_iter->second.i(); + } + + // Derive K from the weight input shape if available. Shape information may be missing even + // when the weight is a constant initializer, so guard against nullptrs / unknown dims. 
+ int64_t K = 32; // reasonable default consistent with ComputeEffectiveBlockSize default + const auto* weight_arg = dq_node.InputDefs()[0]; + if (weight_arg != nullptr) { + const auto* shape = weight_arg->Shape(); + if (shape != nullptr && shape->dim_size() > 0 && shape->dim(0).has_dim_value()) { + K = static_cast<int64_t>(shape->dim(0).dim_value()); + } + } + + return ComputeEffectiveBlockSize(block_size_for_non_blockwise, K); +} + // Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits. // Used by DQMatMulToMatMulNBitsAction. struct TransposedQuantizedTensors { @@ -56,16 +117,17 @@ struct TransposedQuantizedTensors { // Transpose DQ weight/scale/zp tensors from column-wise layout to MatMulNBits layout via MLAS. // default_zp_name_prefix: prefix for auto-generated zero-point name when unsigned type has no explicit zp. +// effective_block_size: the block_size to use for MatMulNBits (may differ from DQ's block_size for per-tensor/per-channel). Status TransposeDQWeightsForMatMulNBits( Graph& graph, const Node& dq_node, const std::string& default_zp_name_prefix, concurrency::ThreadPool* intra_op_thread_pool, + int64_t effective_block_size, TransposedQuantizedTensors& result) { const auto* weight_arg = dq_node.InputDefs()[0]; const auto* scale_arg = dq_node.InputDefs()[1]; const auto* zp_arg = dq_node.InputDefs().size() > 2 ?
dq_node.InputDefs()[2] : nullptr; - const auto& attrs = dq_node.GetAttributes(); const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), @@ -78,9 +140,11 @@ Status TransposeDQWeightsForMatMulNBits( graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); } - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = attrs.at("block_size").i(); + ORT_RETURN_IF_NOT(weight_tensor_proto->dims_size() >= 2, + "Weight tensor for node ", dq_node.Name(), " must be at least 2D."); + auto K = weight_tensor_proto->dims(0); + auto N = weight_tensor_proto->dims(1); + auto block_size = effective_block_size; int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); auto bits = DQWeightBits(dt_weight); auto quant_num = (K + block_size - 1) / block_size; @@ -94,8 +158,100 @@ Status TransposeDQWeightsForMatMulNBits( std::optional zp_src; auto cpu_allocator = CPUAllocator::DefaultInstance(); + // Determine if scale/zp need expansion from per-tensor/per-channel to blockwise [quant_num, N]. 
+ const bool is_blockwise = (scale_tensor_proto->dims_size() == 2); + std::optional expanded_scale; + std::optional expanded_zp; + + if (!is_blockwise) { + // Expand scale to [quant_num, N] + expanded_scale.emplace(scale_type, TensorShape{quant_num, N}, cpu_allocator); + bool is_per_tensor = (scale_tensor_proto->dims_size() == 0); + + auto expand_scale = [&](auto* src_data, auto* dst_data) { + if (is_per_tensor) { + auto val = src_data[0]; + for (int64_t i = 0; i < quant_num * N; ++i) { + dst_data[i] = val; + } + } else { + // Per-channel: scale shape [N], replicate across quant_num blocks + for (int64_t b = 0; b < quant_num; ++b) { + for (int64_t n = 0; n < N; ++n) { + dst_data[b * N + n] = src_data[n]; + } + } + } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + expand_scale(scale_src.data(), expanded_scale->MutableData()); + } else { + expand_scale(scale_src.data(), expanded_scale->MutableData()); + } + + // Expand zp if present + if (zp_tensor_proto) { + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); + // Allocate as uint8 with enough bytes to hold quant_num*N packed sub-byte elements. + int64_t expanded_zp_bytes = (quant_num * N * bits + 7) / 8; + expanded_zp.emplace(uint8_type, TensorShape{expanded_zp_bytes}, cpu_allocator); + + // For sub-byte types, the zp is packed in bytes. We need to expand element-wise. + // For 8-bit, each byte is one element. For 4-bit, 2 elements per byte. For 2-bit, 4 elements per byte. + const uint8_t* zp_bytes = zp_src->DataAsByteSpan().data(); + uint8_t* dst_zp_bytes = expanded_zp->MutableData(); + + auto get_element = [bits](const uint8_t* data, int64_t idx) -> uint8_t { + if (bits == 8) return data[idx]; + if (bits == 4) { + uint8_t byte = data[idx / 2]; + return (idx % 2 == 0) ? 
(byte & 0x0F) : (byte >> 4); + } + // bits == 2 + uint8_t byte = data[idx / 4]; + int shift = static_cast((idx % 4) * 2); + return (byte >> shift) & 0x03; + }; + + auto set_element = [bits](uint8_t* data, int64_t idx, uint8_t val) { + if (bits == 8) { + data[idx] = val; + return; + } + if (bits == 4) { + int64_t byte_idx = idx / 2; + if (idx % 2 == 0) { + data[byte_idx] = (data[byte_idx] & 0xF0) | (val & 0x0F); + } else { + data[byte_idx] = (data[byte_idx] & 0x0F) | ((val & 0x0F) << 4); + } + return; + } + // bits == 2 + int64_t byte_idx = idx / 4; + int shift = static_cast((idx % 4) * 2); + data[byte_idx] = (data[byte_idx] & ~(0x03 << shift)) | ((val & 0x03) << shift); + }; + + // Initialize expanded zp to 0 + memset(dst_zp_bytes, 0, expanded_zp->SizeInBytes()); + + for (int64_t b = 0; b < quant_num; ++b) { + for (int64_t n = 0; n < N; ++n) { + int64_t src_idx = is_per_tensor ? 0 : n; + uint8_t val = get_element(zp_bytes, src_idx); + set_element(dst_zp_bytes, b * N + n, val); + } + } + } + } + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); result.weight = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + // Zero-initialize: MLAS 4-bit transpose does not zero-pad when K < block_size, + // leaving uninitialized bytes in the last block's padding region. + memset(result.weight.MutableDataRaw(), 0, result.weight.SizeInBytes()); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); auto scale_size = (TensorShape{N, quant_num}).Size(); @@ -104,7 +260,13 @@ Status TransposeDQWeightsForMatMulNBits( std::string zp_dst_name; auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); - if (zp_tensor_proto) { + if (!is_blockwise && expanded_zp.has_value()) { + // Per-tensor/per-channel path with expanded zero-point + zp_dst_name = graph.GenerateNodeArgName( + (zp_arg ? 
zp_arg->Name() : default_zp_name_prefix + "_zero_point") + "_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + } else if (zp_tensor_proto) { + // Blockwise path with explicit zero-point zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); @@ -116,10 +278,15 @@ Status TransposeDQWeightsForMatMulNBits( // Dispatch MLAS transpose based on scale type, bits, and signedness. auto transpose = [&](auto* scale_data, auto* scale_dst_data) { - using ScaleType = std::remove_pointer_t; + using ScaleType = std::remove_const_t>; bool is_signed = IsDQWeightSigned(dt_weight); const uint8_t* src_w = weight_src.DataAsByteSpan().data(); - const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + const uint8_t* src_zp = nullptr; + if (expanded_zp.has_value()) { + src_zp = expanded_zp->Data(); + } else if (zp_src.has_value()) { + src_zp = zp_src->DataAsByteSpan().data(); + } uint8_t* dst_w = result.weight.MutableData(); uint8_t* dst_zp = result.zero_point ? result.zero_point->MutableData() : nullptr; int K_int = static_cast(K); @@ -148,9 +315,11 @@ Status TransposeDQWeightsForMatMulNBits( }; if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - transpose(scale_src.data(), result.scale.MutableData()); + const float* s_data = expanded_scale.has_value() ? expanded_scale->Data() : scale_src.data(); + transpose(s_data, result.scale.MutableData()); } else { - transpose(scale_src.data(), result.scale.MutableData()); + const MLFloat16* s_data = expanded_scale.has_value() ? 
expanded_scale->Data<MLFloat16>() : scale_src.data<MLFloat16>(); + transpose(s_data, result.scale.MutableData<MLFloat16>()); } result.weight_proto = utils::TensorToTensorProto(result.weight, weight_dst_name, true); @@ -430,7 +599,8 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) + concurrency::ThreadPool* intra_op_thread_pool, + int64_t block_size_for_non_blockwise) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -440,7 +610,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + block_size_for_non_blockwise_{block_size_for_non_blockwise} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -449,15 +620,17 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) NodeAttributes extra_attributes; const auto* dq_node = runtime_state.selected_nodes.Input(0); - auto& attrs = dq_node->GetAttributes(); const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + ORT_ENFORCE(weight_shape != nullptr && weight_shape->dim_size() >= 2, + "Weight shape unavailable for DQ node ", dq_node->Name()); utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); int32_t dt_weight = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); utils::SetNodeAttribute(utils::MakeAttribute("bits", DQWeightBits(dt_weight)), extra_attributes); -
utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); + int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", effective_bs), extra_attributes); return extra_attributes; } @@ -467,9 +640,11 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, Node& replacement_node) const { const auto* dq_node = selected_nodes.Input(0); + int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_); + TransposedQuantizedTensors transposed; ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( - graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, transposed)); + graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, effective_bs, transposed)); auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); @@ -483,6 +658,31 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, replacement_node.MutableInputArgsCount().push_back(1); } + // If the target was Gemm, strip Gemm-specific attributes from the replacement MatMulNBits node + // and wire the bias (if present) to MatMulNBits input 5. + const auto& target = selected_nodes.Target(); + if (target.OpType() == "Gemm") { + replacement_node.ClearAttribute("alpha"); + replacement_node.ClearAttribute("beta"); + replacement_node.ClearAttribute("transA"); + replacement_node.ClearAttribute("transB"); + + // Wire Gemm bias to MatMulNBits input 5 (bias slot). + // The bias can be a direct float tensor or the output of a DQ node. + const auto& target_inputs = target.InputDefs(); + if (target_inputs.size() > 2 && target_inputs[2] && target_inputs[2]->Exists()) { + // MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zp(opt), 4:g_idx(opt), 5:bias(opt) + // Pad with empty NodeArgs up to position 5. 
+ NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (input_defs.size() < 5) { + input_defs.push_back(&empty_arg); + replacement_node.MutableInputArgsCount().push_back(1); + } + input_defs.push_back(const_cast(target_inputs[2])); + replacement_node.MutableInputArgsCount().push_back(1); + } + } + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 02a8353707599..f0b1e17a7ffe0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -86,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + int64_t block_size_for_non_blockwise = 0); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -105,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + const int64_t block_size_for_non_blockwise_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 8cab6911646f2..c88ae9b8c4782 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -296,15 +296,19 @@ void MatMulQDQRules(SelectorActionRegistry& 
qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is 2/4/8-bit int (int2/uint2, int4/uint4, int8/uint8). DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. + // Also supports per-tensor and per-channel (axis=1) quantized DQ weights by expanding + // scales/zero-points to blockwise format using qdq_matmulnbits_block_size. const std::string action_name{"DQMatMulToMatMulNBits"}; std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + qdq_matmulnbits_block_size); #if !defined(ORT_MINIMAL_BUILD) // Include "" (empty string) to match nodes not yet assigned to an EP. 
@@ -315,7 +319,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""}; std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, - {{"MatMul", {}}}, + {{"MatMul", {}}, + {"Gemm", {}}}, std::move(selector), std::move(action)); @@ -370,7 +375,8 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { SelectorActionRegistry CreateSelectorActionRegistry( bool is_int8_allowed, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -384,7 +390,8 @@ SelectorActionRegistry CreateSelectorActionRegistry( WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + qdq_matmulnbits_block_size); return qdq_selector_action_registry; } @@ -395,11 +402,13 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) : SelectorActionTransformer{ "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool), + intra_op_thread_pool, + qdq_matmulnbits_block_size), apply_context, // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. 
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index dce1cd44fd3ea..8294c839cfe42 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -29,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + int64_t qdq_matmulnbits_block_size = 0); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 8a00fe11ff3fd..ef9e1b0cad490 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -6,6 +6,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/graph/graph.h" +#include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -558,11 +559,14 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } } -// Validate that a DQ node has the correct structure for MatMulNBits fusion: -// - weight type is 2/4/8-bit int, scale type is float or float16 -// - blockwise quantization along axis 0, block_size is power-of-2 and >= 16 -// - weight/scale/zp are constant initializers with rank 2 and consistent shapes 
-static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq_node) { +// Validate that a DQ node has the correct structure for MatMulNBits fusion. +// Supports three quantization granularities: +// - Blockwise: axis=0, block_size >= 16 and power-of-2, scale/zp rank 2 +// - Per-tensor: scale is scalar (rank 0), no block_size attribute +// - Per-channel (axis=1): scale is 1D with shape [N], weight is 2D [K,N], no block_size attribute +// In all cases: weight type is 2/4/8-bit int, scale type is float or float16, +// weight/scale/zp are constant initializers. +static bool ValidateDQForMatMulNBits(const Graph& graph, const Node& dq_node) { const auto* weight_arg = dq_node.InputDefs()[0]; const auto* scale_arg = dq_node.InputDefs()[1]; const auto* zero_point_arg = dq_node.InputDefs().size() == 3 ? dq_node.InputDefs()[2] : nullptr; @@ -578,22 +582,6 @@ static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq return false; } - // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 - const auto& dq_attrs = dq_node.GetAttributes(); - if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { - return false; - } - - const auto a_iter = dq_attrs.find("block_size"); - if (a_iter == dq_attrs.end()) { - return false; - } - - auto block_size = a_iter->second.i(); - if (block_size < 16 || ((block_size - 1) & block_size)) { - return false; - } - // weight, scale and zero points (if exists) must be constants const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); @@ -607,18 +595,124 @@ static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq return false; } - // weight, scale and zero points (if exists) must have the rank 2 - if (weight_tensor_proto->dims_size() != 2 || scale_tensor_proto->dims_size() != 2 || - 
(zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + // weight must be rank 2 + if (weight_tensor_proto->dims_size() != 2) { return false; } - // check weight, scale and zero points (if exists) shapes - if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || - weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || - (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || - zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + const auto& dq_attrs = dq_node.GetAttributes(); + const auto block_size_iter = dq_attrs.find("block_size"); + const bool has_block_size = block_size_iter != dq_attrs.end() && block_size_iter->second.i() > 0; + + if (has_block_size) { + // --- Blockwise path (existing logic) --- + if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return false; + } + + auto block_size = block_size_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return false; + } + + if (scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return false; + } + + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return false; + } + } else { + // --- Per-tensor or per-channel path --- + int scale_rank = scale_tensor_proto->dims_size(); + auto N = weight_tensor_proto->dims()[1]; + + if (scale_rank == 0) { + // Per-tensor: scalar scale, optional scalar zp + if (zp_tensor_proto && zp_tensor_proto->dims_size() != 0) { + return false; + } + } else if (scale_rank == 1 && scale_tensor_proto->dims()[0] == N) { + // Per-channel (axis=1): scale shape [N], axis must be 1 + const auto 
a_iter = dq_attrs.find("axis"); + // DQ default axis is 1, so absent axis is OK + if (a_iter != dq_attrs.end() && a_iter->second.i() != 1) { + return false; + } + if (zp_tensor_proto && (zp_tensor_proto->dims_size() != 1 || zp_tensor_proto->dims()[0] != N)) { + return false; + } + } else { + // Unsupported quantization granularity + return false; + } + } + + return true; +} + +// Validate Gemm attributes for DQ->MatMulNBits fusion. +// Gemm must be equivalent to MatMul: alpha=1, transA=0, transB=0. +// If bias exists, beta must be 1 and bias shape must be [N]. +static bool ValidateGemmForDQMatMulNBits(const Graph& graph, const Node& gemm_node, const Node& weight_dq_node) { + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm_node, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) + return false; + if (const auto* trans_a = graph_utils::GetNodeAttribute(gemm_node, "transA"); + trans_a && trans_a->i() != 0) + return false; + if (const auto* trans_b = graph_utils::GetNodeAttribute(gemm_node, "transB"); + trans_b && trans_b->i() != 0) return false; + + const auto& inputs = gemm_node.InputDefs(); + if (inputs.size() > 2 && inputs[2] && inputs[2]->Exists()) { + // Bias exists — beta must be 1.0 + if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm_node, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) + return false; + + // Bias shape must be [N] where N = weight dim 1. Prefer reading N and + // bias length from constant initializers when available, and fall back to + // NodeArg::Shape(). 
+ const auto* weight_arg = weight_dq_node.InputDefs()[0]; + const auto* weight_initializer = graph.GetConstantInitializer(weight_arg->Name(), true); + int64_t N = -1; + + if (weight_initializer) { + if (weight_initializer->dims_size() != 2) { + return false; + } + N = weight_initializer->dims(1); + } else { + const auto* weight_shape = weight_arg->Shape(); + if (!weight_shape || weight_shape->dim_size() != 2 || + !utils::HasDimValue(weight_shape->dim(1))) { + return false; + } + N = weight_shape->dim(1).dim_value(); + } + + const auto* bias_arg = inputs[2]; + const auto* bias_initializer = graph.GetConstantInitializer(bias_arg->Name(), true); + + if (bias_initializer) { + if (bias_initializer->dims_size() != 1 || + bias_initializer->dims(0) != N) { + return false; + } + } else { + const auto* bias_shape = bias_arg->Shape(); + if (!bias_shape || bias_shape->dim_size() != 1 || + !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != N) { + return false; + } + } } return true; @@ -637,18 +731,55 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod } const auto& graph = graph_viewer.GetGraph(); + const bool is_gemm = node.OpType() == "Gemm"; + + if (is_gemm) { + // Gemm: accept 1 DQ (weight only) or 2 DQs (weight + bias). 
+ if (dq_nodes.size() < 1 || dq_nodes.size() > 2) { + return false; + } + } else { + // MatMul: exactly 1 DQ input + if (dq_nodes.size() != 1) { + return false; + } + } - // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output - if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + // Find the weight DQ node — the one feeding input 1 (B) + const Node* weight_dq = nullptr; + for (const auto* dq : dq_nodes) { + if (node.InputDefs()[1] == dq->OutputDefs()[0]) { + weight_dq = dq; + break; + } + } + + if (!weight_dq) { return false; } - // DQ must be MatMul's the second input - if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + // Weight DQ must have exactly 1 output edge and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *weight_dq, 1)) { return false; } - return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]); + if (is_gemm) { + // If there's a second DQ node (for bias), it must feed input 2 + if (dq_nodes.size() == 2) { + const Node* bias_dq = (dq_nodes[0] == weight_dq) ? dq_nodes[1] : dq_nodes[0]; + if (node.InputDefs().size() <= 2 || !node.InputDefs()[2] || + node.InputDefs()[2] != bias_dq->OutputDefs()[0]) { + return false; + } + } + + // Validate Gemm attributes (alpha=1, transA=0, transB=0, beta=1 if bias) + if (!ValidateGemmForDQMatMulNBits(graph, node, *weight_dq)) { + return false; + } + } + + return ValidateDQForMatMulNBits(graph, *weight_dq); } bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, @@ -701,6 +832,13 @@ void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex); } +void DQMatMulToMatMulNBitsSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { + // Keep only the weight DQ (first entry). 
If a Gemm has a bias DQ, it will be in + position 1 — trim it so RemoveNodes does not delete it. The bias DQ's output + is wired to MatMulNBits input 5 in ProcessNewNode. + builder.input_nodes.resize(1); +} + bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector<const Node*>& dq_nodes, const std::vector<const Node*>& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 79c374b301442..10d307b4a003c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -454,11 +454,15 @@ class MatMulSelector : public BaseSelector { compatible_providers) {} }; -// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +// Convert "1 DQ node for input B -> MatMul/Gemm" to "MatMulNBits" class DQMatMulToMatMulNBitsSelector : public BaseSelector { public: explicit DQMatMulToMatMulNBitsSelector(gsl::span<const char*> compatible_providers = {}) : BaseSelector(std::make_unique<DQMatMulNodeGroupSelector>(), compatible_providers) {} + + // Only keep the weight DQ in the selection. Any bias DQ (for Gemm) is excluded + // so that RemoveNodes does not remove it — its output is wired through to MatMulNBits.
+ void UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const override; }; // Input: DQ nodes for A, B and optional C diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 74b8f8e468097..9f19a20a2e680 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1220,6 +1220,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); // Opset 20 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, 20, ConstantOfShape); @@ -1316,6 +1318,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Ac class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Atanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, DeformConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Det); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_float, Dropout); class 
// Bilinear interpolation at (h, w). Out-of-bounds samples return 0 (ONNX spec).
// Indices use int (not int64_t) to reduce register pressure in the hot path.
// Limitation: height and width must not exceed INT_MAX, or casting floor(h)/floor(w)
// to int may overflow. Acceptable in practice: deformable convolution spatial
// dimensions are typically well below INT_MAX (validated by the caller).
template <typename T>
T BilinearInterpolate(const T* in, int height, int width, T h, T w) {
  // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case).
  if (h <= static_cast<T>(-1) || h >= height || w <= static_cast<T>(-1) || w >= width) {
    return static_cast<T>(0);
  }

  // [Optimization 2]: Keep floor result in T; cast to int only for indices.
  // Avoids a float->int->float round trip when computing lh/lw.
  const T h_floor = std::floor(h);
  const T w_floor = std::floor(w);
  const int h_low = static_cast<int>(h_floor);
  const int w_low = static_cast<int>(w_floor);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const T lh = h - h_floor;
  const T lw = w - w_floor;
  const T hh = static_cast<T>(1) - lh;
  const T hw = static_cast<T>(1) - lw;

  // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)).
  // Most sampling points in deformable conv fall here; avoids 4 per-corner branches.
  // [Optimization 3]: Unsigned comparison folds the "negative index" check into one test.
  if (static_cast<unsigned>(h_low) < static_cast<unsigned>(height - 1) &&
      static_cast<unsigned>(w_low) < static_cast<unsigned>(width - 1)) {
    const int base_low = h_low * width;
    const int base_high = h_high * width;
    return hh * hw * in[base_low + w_low] +
           hh * lw * in[base_low + w_high] +
           lh * hw * in[base_high + w_low] +
           lh * lw * in[base_high + w_high];
  }

  // Slow path: near boundary (one or more of the 4 corners may be out of bounds).
  const int base_low = h_low * width;
  const int base_high = h_high * width;
  const T v1 = (h_low >= 0 && w_low >= 0) ? in[base_low + w_low] : static_cast<T>(0);
  const T v2 = (h_low >= 0 && w_high < width) ? in[base_low + w_high] : static_cast<T>(0);
  const T v3 = (h_high < height && w_low >= 0) ? in[base_high + w_low] : static_cast<T>(0);
  const T v4 = (h_high < height && w_high < width) ? in[base_high + w_high] : static_cast<T>(0);
  return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
}
+template +void DeformableIm2col( + const T* data_im, // Input image [C, H, W] + const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] + const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false) + int height, int width, // Input spatial dimensions (validated H*W <= INT_MAX) + int64_t kernel_h, int64_t kernel_w, // Kernel dimensions + int64_t pad_h, int64_t pad_w, // Padding (begin) for H and W + int64_t stride_h, int64_t stride_w, // Stride for H and W + int64_t dilation_h, int64_t dilation_w, // Dilation for H and W + int64_t channels, // Input channels + int64_t offset_groups, // Number of offset groups (channels shared per group) + int64_t height_col, int64_t width_col, // Output spatial dimensions (H_out, W_out) + T* data_col, // Output buffer for im2col result + concurrency::ThreadPool* thread_pool) { + const int64_t channel_per_offset_group = channels / offset_groups; + const int64_t kernel_size = kernel_h * kernel_w; + const int64_t output_size = height_col * width_col; + + // Parallelize over (channel, kernel_position) so each task processes one full row of data_col. + // This yields channels*kernel_size tasks, better CPU utilization and cache-friendly sequential writes. + concurrency::ThreadPool::TryParallelFor( + thread_pool, + static_cast(channels * kernel_size), + static_cast(output_size) * 10.0, + [&](ptrdiff_t begin, ptrdiff_t end) { + for (ptrdiff_t idx = begin; idx < end; ++idx) { + // Decompose idx into (c_im, i, j): which channel and kernel position. + const int64_t j = static_cast(idx) % kernel_w; + const int64_t i = (static_cast(idx) / kernel_w) % kernel_h; + const int64_t c_im = static_cast(idx) / kernel_size; + const int64_t offset_grp = c_im / channel_per_offset_group; + + // Output row: one (channel, kernel_pos) across all spatial locations. 
+ T* col_ptr = data_col + static_cast(idx) * output_size; + const T* im_ptr = data_im + c_im * static_cast(height) * width; + + // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened. + // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w. + // Precompute pointers to avoid offset_base * output_size multiplication in inner loop. + const int64_t offset_base = + offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j); + const T* ptr_offset_h = data_offset + offset_base * output_size; + const T* ptr_offset_w = data_offset + (offset_base + 1) * output_size; + + // Base terms for h_im, w_im: invariant in inner loop (i, j fixed). + const T base_h = -pad_h + static_cast(i) * dilation_h; + const T base_w = -pad_w + static_cast(j) * dilation_w; + + // Mask pointer; only used when UseMask=true (compiler removes when false). + [[maybe_unused]] const T* ptr_mask = nullptr; + if constexpr (UseMask) { + ptr_mask = data_mask + (offset_grp * kernel_size + i * kernel_w + j) * output_size; + } + + // Loop over output spatial positions. + for (int64_t h_col = 0; h_col < height_col; ++h_col) { + for (int64_t w_col = 0; w_col < width_col; ++w_col) { + const int64_t spatial_idx = h_col * width_col + w_col; + + const T offset_h = ptr_offset_h[spatial_idx]; + const T offset_w = ptr_offset_w[spatial_idx]; + + // Deformed sampling coordinates (fractional, for bilinear interpolation). + const T h_im = h_col * stride_h + base_h + offset_h; + const T w_im = w_col * stride_w + base_w + offset_w; + + // Sample input at deformed location; returns 0 if out of bounds. + T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im); + + // Modulate by mask when UseMask=true; compiled away when false. + // Design choice: we always interpolate then multiply, rather than skip when mask==0. + // Rationale: (1) Skipping adds a branch; unpredictable mask values cause misprediction + // penalties (~15-20 cycles). 
(2) Straight-line code vectorizes better; conditional + // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN + // usage (moderate mask density), the unconditional path usually wins. + if constexpr (UseMask) { + val *= ptr_mask[spatial_idx]; + } + + col_ptr[spatial_idx] = val; + } + } + } + }); +} + +} // namespace + +template +Status DeformConv::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); // optional + const auto* mask = context->Input(4); // optional + + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse( + attrs_, + X->Shape(), + W->Shape(), + offset->Shape(), + B ? &B->Shape() : nullptr, + mask ? &mask->Shape() : nullptr, + params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; + + // Allocate output tensor [N, M, out_h, out_w]. + const TensorShape Y_shape({N, M, out_h, out_w}); + Tensor* Y = context->Output(0, Y_shape); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + // Precompute common sizes for the im2col + GEMM pipeline. 
+ const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = C / group * kernel_size; // K dimension for GEMM: C/group * kH * kW + + // Col buffer: shape [C*kH*kW, out_h*out_w]. Allocate per-image (process one image at a time) + // to reduce peak memory when N is large; im2col is implemented per-image anyway. + const int64_t col_buffer_size = (C * kernel_size) * output_image_size; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ? B->Data() : nullptr; + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + // Process each image in the batch. + for (int64_t n = 0; n < N; ++n) { + // Step 1: Deformable Im2Col for image n. + // Gather deformed samples into col buffer for GEMM. + const T* X_curr = Xdata + n * (C * input_image_size); + const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr; + T* col_buffer_ptr = col_buffer.get(); + + // Dispatch to template instantiation: UseMask=true or false eliminates branch in hot loop. + // Note: pad_h, pad_w are begin-side paddings for coordinate mapping; pad_h_end/pad_w_end + // affect only output size (already baked into out_h, out_w), not im2col sampling. 
+ if (use_mask) { + DeformableIm2col( + X_curr, offset_curr, mask_curr, + static_cast(H), static_cast(W_in), kH, kW, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + C, offset_group, out_h, out_w, + col_buffer_ptr, thread_pool); + } else { + DeformableIm2col( + X_curr, offset_curr, nullptr, + static_cast(H), static_cast(W_in), kH, kW, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + C, offset_group, out_h, out_w, + col_buffer_ptr, thread_pool); + } + + // Step 2: GEMM for each group. Y = W * Col (per group). + for (int64_t g = 0; g < group; ++g) { + // Weight for group g: shape [M/group, C/group, kH, kW], row-major. + const T* weight_g = Wdata + g * (M / group) * kernel_dim; + + // Col rows for group g: layout [C*kH*kW, out_h*out_w], group g spans rows [g*kernel_dim, (g+1)*kernel_dim). + const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size; + + // Output slice for group g: [n, g*M/group:(g+1)*M/group, out_h, out_w]. + T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size; + + // GEMM: Y = W * Col. W [M/group, kernel_dim], Col [kernel_dim, output_image_size]. + math::Gemm( + CblasNoTrans, + CblasNoTrans, + narrow(M / group), // M + narrow(output_image_size), // N + narrow(kernel_dim), // K + static_cast(1), // alpha + weight_g, // A + col_g, // B + static_cast(0), // beta + Y_g, // C + thread_pool, + nullptr); // mlas_backend_kernel_selector_config + } + } + + // Step 3: Add bias if provided (broadcast over spatial dimensions). + if (Bdata != nullptr) { + int64_t total_work = N * M; + concurrency::ThreadPool::TryParallelFor( + thread_pool, static_cast(total_work), static_cast(output_image_size), + [&](ptrdiff_t first, ptrdiff_t last) { + for (ptrdiff_t idx = first; idx < last; ++idx) { + int64_t n = idx / M; + int64_t m = idx % M; + T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size; + // Eigen vectorized add: Y_ptr += Bdata[m] over all spatial positions. 
+ EigenVectorArrayMap(Y_ptr, narrow(output_image_size)) += Bdata[m]; + } + }); + } + + return Status::OK(); +} + +// Explicit template instantiation for float and double +template class DeformConv; +template class DeformConv; + +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + DeformConv, 19, 21, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + DeformConv, 22, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float) +REGISTER_DEFORMCONV_KERNEL_TYPED(double) + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h new file mode 100644 index 0000000000000..c8d7763e58bcb --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/op_node_proto_helper.h" +#include "deform_conv_attributes.h" + +namespace onnxruntime { + +template +class DeformConv : public OpKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : OpKernel(info), attrs_(info) {} + + Status Compute(OpKernelContext* context) const override; + + private: + DeformConvAttributes attrs_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h new file mode 100644 index 0000000000000..8bc891bb4f377 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { + +// Shared attributes for ONNX DeformConv (opset 19+). +// See https://onnx.ai/onnx/operators/onnx__DeformConv.html +// Used by both CPU and CUDA implementations (CUDA includes from here). +struct DeformConvAttributes { + explicit DeformConvAttributes(const OpKernelInfo& info) { + // Optional attributes. + // If not present, they will be empty/default, and handled in Compute/ComputeInternal. + (void)info.GetAttrs("kernel_shape", kernel_shape); + (void)info.GetAttrs("strides", strides); + (void)info.GetAttrs("pads", pads); + (void)info.GetAttrs("dilations", dilations); + group = info.GetAttrOrDefault("group", 1); + offset_group = info.GetAttrOrDefault("offset_group", 1); + } + + TensorShapeVector kernel_shape; + TensorShapeVector strides; + TensorShapeVector pads; + TensorShapeVector dilations; + int64_t group{1}; + int64_t offset_group{1}; +}; + +// Parsed and validated parameters from DeformConv inputs. +// Used by both CPU and CUDA implementations. 
+// Field names align with ONNX DeformConv spec: https://onnx.ai/onnx/operators/onnx__DeformConv.html +struct DeformConvParams { + // Input X shape (N, C, H, W) + int64_t N{0}; // Batch size + int64_t C{0}; // Number of input channels + int64_t H{0}; // Input height + int64_t W_in{0}; // Input width (W_in to avoid collision with weight W) + + // Weight W shape (oC, C/group, kH, kW) + int64_t M{0}; // Number of output channels (oC) + int64_t kH{0}; // Kernel height + int64_t kW{0}; // Kernel width + + // Pads [x1_begin, x2_begin, x1_end, x2_end] for spatial axes H, W + int64_t pad_h{0}; + int64_t pad_w{0}; + int64_t pad_h_end{0}; + int64_t pad_w_end{0}; + + // Strides and dilations along each spatial axis (default 1) + int64_t stride_h{1}; + int64_t stride_w{1}; + int64_t dilation_h{1}; + int64_t dilation_w{1}; + + // Attributes: C and oC must be divisible by group; C must be divisible by offset_group + int64_t group{1}; // Number of groups for input/output channels + int64_t offset_group{1}; // Number of groups of offset + + // Output Y shape (N, oC, oH, oW) + int64_t out_h{0}; // Output height (oH) + int64_t out_w{0}; // Output width (oW) + + bool use_mask{false}; // Whether optional mask input is provided +}; + +// Validates inputs and parses attributes into params. +// Returns Status::OK() on success; on failure, params may be partially filled. 
+inline Status DeformConvValidateAndParse( + const DeformConvAttributes& attrs, + const TensorShape& X_shape, + const TensorShape& W_shape, + const TensorShape& offset_shape, + const TensorShape* B_shape, + const TensorShape* mask_shape, + DeformConvParams& params) { + ORT_RETURN_IF_NOT(X_shape.NumDimensions() == 4, "Input X must be 4D (N, C, H, W)."); + ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); + ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + + // Parse input shapes + params.N = X_shape[0]; + params.C = X_shape[1]; + params.H = X_shape[2]; + params.W_in = X_shape[3]; + params.M = W_shape[0]; + ORT_RETURN_IF_NOT(params.N >= 0, "Batch size N must be non-negative."); + ORT_RETURN_IF_NOT(params.C > 0, "Input channels C must be positive."); + ORT_RETURN_IF_NOT(params.M > 0, "Output channels M (oC) must be positive."); + ORT_RETURN_IF_NOT(W_shape[1] > 0, "Weight W must have positive in-channels (W_shape[1] = C/group)."); + + // Handle kernel shape inference. If kernel_shape is provided, it must match weight spatial dims + // to avoid GEMM using wrong K and potential out-of-bounds reads from the weight buffer. + const int64_t W_kH = W_shape[2]; + const int64_t W_kW = W_shape[3]; + if (!attrs.kernel_shape.empty()) { + ORT_RETURN_IF_NOT(attrs.kernel_shape.size() == 2, + "kernel_shape must be absent or have exactly 2 values (kH, kW) for 2D DeformConv."); + ORT_RETURN_IF_NOT(attrs.kernel_shape[0] == W_kH && attrs.kernel_shape[1] == W_kW, + "kernel_shape must match weight spatial dimensions (W_shape[2], W_shape[3])."); + params.kH = attrs.kernel_shape[0]; + params.kW = attrs.kernel_shape[1]; + } else { + params.kH = W_kH; + params.kW = W_kW; + } + + // DeformConv is 2D-only: when an attribute is present, require exact length to avoid silently misinterpreting malformed models. 
+ params.pad_h = params.pad_w = params.pad_h_end = params.pad_w_end = 0; + if (!attrs.pads.empty()) { + ORT_RETURN_IF_NOT(attrs.pads.size() == 4, + "pads must be absent or have exactly 4 values [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] for 2D DeformConv."); + params.pad_h = attrs.pads[0]; + params.pad_w = attrs.pads[1]; + params.pad_h_end = attrs.pads[2]; + params.pad_w_end = attrs.pads[3]; + ORT_RETURN_IF_NOT(params.pad_h >= 0 && params.pad_w >= 0 && params.pad_h_end >= 0 && params.pad_w_end >= 0, + "Pads must be non-negative (ONNX spec)."); + } + + if (!attrs.strides.empty()) { + ORT_RETURN_IF_NOT(attrs.strides.size() == 2, + "strides must be absent or have exactly 2 values [stride_h, stride_w] for 2D DeformConv."); + params.stride_h = attrs.strides[0]; + params.stride_w = attrs.strides[1]; + } else { + params.stride_h = params.stride_w = 1; + } + + if (!attrs.dilations.empty()) { + ORT_RETURN_IF_NOT(attrs.dilations.size() == 2, + "dilations must be absent or have exactly 2 values [dilation_h, dilation_w] for 2D DeformConv."); + params.dilation_h = attrs.dilations[0]; + params.dilation_w = attrs.dilations[1]; + } else { + params.dilation_h = params.dilation_w = 1; + } + params.group = attrs.group; + params.offset_group = attrs.offset_group; + params.use_mask = (mask_shape != nullptr); + + // Validate attributes + ORT_RETURN_IF_NOT(params.stride_h > 0 && params.stride_w > 0, "Strides must be positive."); + ORT_RETURN_IF_NOT(params.dilation_h > 0 && params.dilation_w > 0, "Dilations must be positive."); + ORT_RETURN_IF_NOT(params.kH > 0 && params.kW > 0, "Kernel shape must be positive."); + ORT_RETURN_IF_NOT(params.group > 0, "group must be positive"); + ORT_RETURN_IF_NOT(params.offset_group > 0, "offset_group must be positive"); + + params.out_h = (params.H + params.pad_h + params.pad_h_end - params.dilation_h * (params.kH - 1) - 1) / params.stride_h + 1; + params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 
1) / params.stride_w + 1; + ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative."); + + // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math. + ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative."); + ORT_RETURN_IF_NOT(params.W_in <= static_cast(INT_MAX) / (params.H + 1), + "Input (H+1)*W must not exceed INT_MAX (for performance optimization)."); + + // Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW). + ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); + const int64_t offset_block = 2 * params.kH * params.kW; + ORT_RETURN_IF_NOT(offset_block > 0 && offset_shape[1] % offset_block == 0 && + offset_shape[1] / offset_block == params.offset_group, + "Offset channel count must be offset_group * 2 * kH * kW."); + ORT_RETURN_IF_NOT(offset_shape[2] == params.out_h, "Offset spatial height must match output oH."); + ORT_RETURN_IF_NOT(offset_shape[3] == params.out_w, "Offset spatial width must match output oW."); + ORT_RETURN_IF_NOT(params.C % params.offset_group == 0, "Input channels must be divisible by offset_group."); + ORT_RETURN_IF_NOT(params.C == W_shape[1] * params.group, "Input channels must match weight in channels * group."); + ORT_RETURN_IF_NOT(params.M % params.group == 0, "Output channels must be divisible by group."); + + if (B_shape != nullptr) { + ORT_RETURN_IF_NOT(B_shape->NumDimensions() == 1, "Bias B must be 1D."); + ORT_RETURN_IF_NOT((*B_shape)[0] == params.M, "Bias B must have shape [M] (M = number of output channels)."); + } + + // Validate mask if present + if (params.use_mask) { + ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D."); + ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size."); + const int64_t mask_block = 
params.kH * params.kW; + ORT_RETURN_IF_NOT(mask_block > 0 && (*mask_shape)[1] % mask_block == 0 && + (*mask_shape)[1] / mask_block == params.offset_group, + "Mask channel count must be offset_group * kH * kW."); + ORT_RETURN_IF_NOT((*mask_shape)[2] == params.out_h, "Mask spatial height must match output oH."); + ORT_RETURN_IF_NOT((*mask_shape)[3] == params.out_w, "Mask spatial width must match output oW."); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h index bb97de158369b..4ce4825e1d78c 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.h +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h @@ -129,6 +129,10 @@ class RoiAlignBase { std::string coordinate_transformation_mode; if (info.template GetAttr("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) { half_pixel_ = coordinate_transformation_mode == "half_pixel"; + } else { + // For opset 16+, the default is "half_pixel" per ONNX spec. + // For opset 10 (which has no coordinate_transformation_mode attribute), false is correct. 
+ half_pixel_ = info.node().SinceVersion() >= 16; } if (mode_ == RoiAlignMode::max && sampling_ratio_ != 1) { diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index d55973eda180f..e34af83d1f29e 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -88,16 +88,9 @@ Status Gelu::Compute(OpKernelContext* context) const { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } + // MlasComputeGeluErf requires distinct input/output buffers. This + // call uses disjoint slices from the input and output tensors. + MlasComputeGeluErf(p_input, p_output, narrow(count)); }, 0); return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h index 13238028d878a..14a070609a69b 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.h +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once + namespace onnxruntime { template diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index 7dcf88133e967..ff5498c0b4644 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -265,8 +265,17 @@ class UpsampleBase { if (scales_input_idx_ > 0) { const Tensor* scale; bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale); - auto x_shape = node.InputDefs()[0]->Shape(); - int64_t rank = x_shape ? 
x_shape->dim_size() : -1; + int64_t rank = -1; + if constexpr (std::is_same_v) { + auto x_shape = node.InputDefs()[0]->Shape(); + if (x_shape != nullptr) { + rank = x_shape->dim_size(); + } + } else { + auto type_info = info.GetKernelInfo().GetInputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + rank = static_cast(tensor_info.GetDimensionsCount()); + } if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) { ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank)); scales_cached_ = true; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b87cf8cbc16c1..4c735fa2d5650 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -944,8 +944,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu); @@ -1441,10 +1444,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize); @@ -1452,9 +1455,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize); // Opset 19 +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast); @@ -1573,6 +1583,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, bool, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Identity); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, If); @@ -1596,6 +1610,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, DeformConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 
22, float, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid); @@ -1608,6 +1626,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, GRU); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); // Opset 23. 
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -1639,6 +1661,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); #endif +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Transpose); @@ -1663,10 +1689,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, MLFloat16, Pad); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, bool, Pad); // Opset 25. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Unsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, ConstantOfShape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, If); @@ -2063,8 +2097,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2560,10 +2597,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2571,9 +2608,16 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 19-20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2641,6 +2685,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 21 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // TODO(fajin): support other quantized types BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2706,6 +2754,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2728,8 +2780,16 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2782,10 +2842,18 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 25 BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc new file mode 100644 index 
0000000000000..7a0b896acfe01 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv (deformable convolution 2D). + +#include "core/providers/shared_library/provider_api.h" +#include "deform_conv.h" +#include "deform_conv_impl.h" + +#include + +#include "core/common/narrow.h" +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kMaxParallelImgs = 32; + +// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes. +// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately. +// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration +// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision. +int GetGreatestDivisorBelowBound(int n, int bound) { + if (bound <= 0 || n <= 0) return 1; + if (n % bound == 0) return bound; // Fast path: batch is multiple of target + + // n >= bound^2 <=> bound <= sqrt(n) => linear scan is cheaper + if (static_cast(n) >= static_cast(bound) * bound) { + for (int k = bound - 1; k > 1; --k) { + if (n % k == 0) return k; + } + } else { + // n < bound^2 <=> bound > sqrt(n) => divisor enumeration is cheaper + int best = 1; + for (int i = 1; static_cast(i) * i <= static_cast(n); ++i) { + if (n % i != 0) continue; + const int q = n / i; + if (q <= bound && q > best) best = q; + if (i <= bound && i > best) best = i; + } + return best; + } + return 1; +} + +// Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers. +// Uses a fraction of total GPU memory to avoid OOM while leaving room for weights, activations, +// and other ops. 
No CUDA API is called; total_global_mem is expected from cached device props. +// +// Formula: +// budget = total_global_mem * kFraction +// return clamp(budget, kMin, kMax) +// with kFraction = 0.1 (10%), kMin = 32 MiB, kMax = 2 GiB. +// +// Example results (effective_max_temp after clamp): +// GPU | totalGlobalMem | effective_max_temp +// -----------------|----------------|-------------------- +// A100 80GB | 80 GiB | 2 GiB (capped) +// RTX 5080 16GB | 16 GiB | 1.6 GiB +// RTX 4090 24GB | 24 GiB | 2 GiB (capped) +// RTX 3080 10GB | 10 GiB | 1 GiB +// GTX 1060 6GB | 6 GiB | 614.4 MiB +// GTX 1050 4GB | 4 GiB | 409.6 MiB +// Jetson 2GB | 2 GiB | 204.8 MiB +size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) { + constexpr double kFraction = 0.1; + constexpr size_t kMin = 32ULL * 1024 * 1024; + constexpr size_t kMax = 2ULL * 1024 * 1024 * 1024; + size_t budget = static_cast(static_cast(total_global_mem) * kFraction); + return std::clamp(budget, kMin, kMax); +} + +// Returns how many images to process in parallel per batch chunk for DeformConv. +// Chooses the largest divisor of batch size N that fits in the temp budget and does not +// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder). +// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1, +// so batching is effectively disabled (single-image chunks). 
+// +// Formulas: +// kernel_size = kH * kW +// output_image_size = out_h * out_w +// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T) +// (temp bytes per image: im2col col buffer + GEMM output buffer per output position) +// max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image)) +// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem) +// return GetGreatestDivisorBelowBound(N, target_parallel_imgs) +template +int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) { + const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem); + const int64_t kernel_size = params.kH * params.kW; + const int64_t output_image_size = params.out_h * params.out_w; + const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); + const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); + const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); + return GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); +} + +} // namespace + +template +Status DeformConv::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); + const auto* mask = context->Input(4); + + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse( + attrs_, + X->Shape(), + W->Shape(), + offset->Shape(), + B ? &B->Shape() : nullptr, + mask ? 
&mask->Shape() : nullptr, + params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; + + Tensor* Y = context->Output(0, {N, M, out_h, out_w}); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const int n_parallel_imgs = GetNParallelImgs(params, GetDeviceProp().totalGlobalMem); + + const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = (C / group) * kernel_size; + + const int64_t col_stride = static_cast(n_parallel_imgs) * output_image_size; + const int64_t col_buffer_size = (C * kernel_size) * col_stride; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + // Removed col_transposed allocation as we avoid physical transpose. + auto gemm_output_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt((M / group) * col_stride)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ? 
B->Data() : nullptr; + + cudaStream_t stream = Stream(context); + cublasHandle_t cublas = GetCublasHandle(context); + const cudaDeviceProp& device_prop = GetDeviceProp(); + CudaT alpha = ToCudaType::FromFloat(1.0f); + CudaT beta = ToCudaType::FromFloat(0.0f); + + for (int64_t b = 0; b < N; b += n_parallel_imgs) { + const int cur_parallel = static_cast(std::min(static_cast(n_parallel_imgs), N - b)); + const int64_t cur_out_size = static_cast(cur_parallel) * output_image_size; + + const T* X_block = Xdata + b * (C * input_image_size); + const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr; + + ORT_RETURN_IF_ERROR(DeformConvIm2ColImpl( + stream, + X_block, + offset_block, + mask_block, + col_buffer.get(), + cur_parallel, + C, + H, + W_in, + kH, + kW, + out_h, + out_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + offset_group, + use_mask)); + + // GEMM layout trick: compute Y = W * Col without physical transpose. + // + // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size]. + // cuBLAS is column-major. Key insight: row-major A[M,K] in memory equals column-major A^T[K,M]. + // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose): + // - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T + // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T + // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major + // + // m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size. + // + // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write + // directly into Y_g. 
Use strided batched for all groups in one call. + // cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. + + const bool gemm_writes_directly = (cur_parallel == 1); + if (gemm_writes_directly) { + // Strided batched: one call for all groups. Strides between batches: + const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 + const int64_t stride_weight = (M / group) * kernel_dim; + const int64_t stride_y = (M / group) * output_image_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(output_image_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_buffer.get()), + narrow(output_image_size), + stride_col, + reinterpret_cast(Wdata), + narrow(kernel_dim), + stride_weight, + &beta, + reinterpret_cast(Ydata + b * M * output_image_size), + narrow(output_image_size), + stride_y, + narrow(group), + device_prop, + UseTF32())); + } else { + // cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group. 
+ for (int64_t g = 0; g < group; ++g) { + const T* W_g = Wdata + g * (M / group) * kernel_dim; + const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; + T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; + + CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(cur_out_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_g), + narrow(cur_out_size), + reinterpret_cast(W_g), + narrow(kernel_dim), + &beta, + reinterpret_cast(gemm_output_buffer.get()), + narrow(cur_out_size), + device_prop, + UseTF32()))); + + ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW( + stream, + gemm_output_buffer.get(), + Y_g, + M, + M / group, + output_image_size, + cur_parallel)); + } + } + } + + if (Bdata != nullptr) { + ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w)); + } + + return Status::OK(); +} + +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 19, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float); +REGISTER_DEFORMCONV_KERNEL_TYPED(double); +REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16); + +// BFloat16 only for opset 22; opset 19-21 do not support BFloat16. 
+ONNX_OPERATOR_TYPED_KERNEL_EX( + DeformConv, + kOnnxDomain, + 22, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + DeformConv); + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h new file mode 100644 index 0000000000000..fa564641d4b98 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/deform_conv_attributes.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +template +class DeformConv final : public CudaKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : CudaKernel(info), attrs_(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + DeformConvAttributes attrs_; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv); +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu new file mode 100644 index 0000000000000..7b3666fca810b --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -0,0 +1,512 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv: deformable im2col kernel + bilinear interpolation. +// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv spec. 
+ +#include "deform_conv_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/fast_divmod.h" +#include "core/common/float16.h" +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kDeformConvThreadsPerBlock = 256; + +template +struct DeformConvKSize { + static constexpr int value = N; +}; + +// Calculate grid size with a safety limit to prevent overflow. +// Since we use grid-stride loops in kernels, limiting the grid size is safe. +inline int GetGridSize(size_t n, size_t threads_per_block) { + size_t blocks_needed = (n + threads_per_block - 1) / threads_per_block; + return static_cast(std::min(blocks_needed, static_cast(std::numeric_limits::max()))); +} + +// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly. +template +__device__ __inline__ T DeformConvLdg(const T* p) { + return __ldg(p); +} +template <> +__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) { + return BFloat16::FromBits(__ldg(reinterpret_cast(p))); +} + +// Traits for bilinear interpolation math: +// - ComputeT: type used for coordinate/weight math (float for half/BFloat16, T otherwise) +// - Load: load one element and convert to ComputeT +// - ToResult: convert ComputeT result back to T +// - Zero: zero value of T +template +struct DeformConvBilinearTraits { + using ComputeT = T; + + __device__ static __inline__ ComputeT Load(const T* p) { + return __ldg(p); + } + + __device__ static __inline__ T ToResult(ComputeT v) { + return v; + } + + __device__ static __inline__ T Zero() { + return static_cast(0); + } +}; + +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const half* p) { + return __half2float(__ldg(p)); + } + + __device__ static __inline__ half ToResult(ComputeT v) { + return __float2half(v); + } + + __device__ static __inline__ half Zero() { + return 
__float2half(0.0f); + } +}; + +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const BFloat16* p) { + return static_cast(DeformConvLdg(p)); + } + + __device__ static __inline__ BFloat16 ToResult(ComputeT v) { + return BFloat16(v); + } + + __device__ static __inline__ BFloat16 Zero() { + return BFloat16(0.0f); + } +}; + +// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). +// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure and +// improve occupancy in the hot path. Limitation: (H+1)*W must not exceed INT_MAX; this is +// validated on the host side in DeformConvValidateAndParse to guarantee index math in int +// does not overflow. For half/BFloat16, coordinate and weight math use float via +// DeformConvBilinearTraits to avoid precision loss. We keep floor() results in CoordT and +// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT +// round trips when computing lh/lw/hh/hw. +template +__device__ __inline__ T BilinearInterpolate( + const T* in, + int height, + int width, + typename DeformConvBilinearTraits::ComputeT h, + typename DeformConvBilinearTraits::ComputeT w) { + using Traits = DeformConvBilinearTraits; + using CoordT = typename Traits::ComputeT; + + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return Traits::Zero(); + } + + // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. 
+ CoordT h_floor = _Floor(h); + CoordT w_floor = _Floor(w); + int h_low = static_cast(h_floor); + int w_low = static_cast(w_floor); + int h_high = h_low + 1; + int w_high = w_low + 1; + + CoordT lh = h - h_floor; + CoordT lw = w - w_floor; + CoordT hh = static_cast(1) - lh; + CoordT hw = static_cast(1) - lw; + + // [Optimization 3]: Avoid a second multiply for base_high. + // Original code computed both bases as: + // base_low = h_low * width; + // base_high = h_high * width; + // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and + // save one integer multiply in the hot path: + // base_low = h_low * width; + // base_high = base_low + width; + int base_low = h_low * width; + int base_high = base_low + width; + + CoordT v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast(0); + CoordT v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast(0); + CoordT v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast(0); + CoordT v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast(0); + + CoordT w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return Traits::ToResult(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); +} + +// kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling. 
+template +__global__ void DeformableIm2ColKernel( + IndexT num_kernels, + const T* __restrict__ input, + const T* __restrict__ offset, + const T* __restrict__ mask, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t channels, + int64_t offset_group, + DivMod out_h_div, + DivMod out_w_div, + DivMod parallel_imgs_div, + DivMod channel_per_offset_grp_div, + bool use_mask, + T* __restrict__ data_col) { + constexpr bool is_fixed = (kH >= 0 && kW >= 0); + const int64_t h_dim = is_fixed ? kH : weight_h; + const int64_t w_dim = is_fixed ? kW : weight_w; + + // Reconstruct dimensions from DivMod objects + const int64_t out_h = out_h_div.d_; + const int64_t out_w = out_w_div.d_; + const int64_t parallel_imgs = parallel_imgs_div.d_; + + const int64_t out_size = out_h * out_w; + // The stride for data_col is (parallel_imgs * out_h * out_w) + const int64_t col_stride = parallel_imgs * out_size; + + using CoordT = typename DeformConvBilinearTraits::ComputeT; + + for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { + IndexT val = index; + IndexT out_x, out_y, out_b, in_c; + + // Fast division/modulo to recover coordinates + out_w_div.divmod(val, val, out_x); + out_h_div.divmod(val, val, out_y); + parallel_imgs_div.divmod(val, in_c, out_b); + + // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case). + IndexT offset_grp = 0; + if (offset_group > 1) { + IndexT dummy; + channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + } + + // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic + // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops. + + // 1. Input pointer base for this batch and channel. 
+ const T* input_ptr = input + static_cast(out_b) * (channels * height * width) + static_cast(in_c) * (height * width); + + // 2. Spatial index in the output feature map. + const int64_t spatial_idx = static_cast(out_y) * out_w + static_cast(out_x); + + // 3. Offset pointer base calculation. + // Layout: (N, offset_groups, 2*KH*KW, OH, OW) + // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx. + const int64_t offset_group_block_size = 2 * h_dim * w_dim * out_size; + const T* offset_ptr_base = offset + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * offset_group_block_size + spatial_idx; + + // 4. Mask pointer base calculation (if used). + // Layout: (N, offset_groups, KH*KW, OH, OW) + const T* mask_ptr_base = nullptr; + if (use_mask) { + const int64_t mask_group_block_size = h_dim * w_dim * out_size; + mask_ptr_base = mask + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * mask_group_block_size + spatial_idx; + } + + // 5. Output pointer base calculation. + // data_col Layout: (C * KH * KW, N * OH * OW) + // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx. + // The starting row for this channel is `in_c * KH * KW`. + const int64_t c_col = static_cast(out_b) * out_size + spatial_idx; + T* data_col_ptr_base = data_col + (static_cast(in_c) * h_dim * w_dim) * col_stride + c_col; + + // 6. Pre-calculate invariant coordinate parts. + // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. + const CoordT base_h_im = static_cast(out_y * stride_h - pad_h); + const CoordT base_w_im = static_cast(out_x * stride_w - pad_w); + + auto process_kernel_point = [&](int64_t i, int64_t j) { + const int64_t kernel_idx = i * w_dim + j; + T mask_val = static_cast(1); + if (use_mask) { + // Access mask using pre-calculated base and stride. 
+ mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); + } + + // Calculate offset pointers relative to the base. + // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. + // Stride between y_offset and x_offset is `out_size`. + const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; + + const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); + + const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; + const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; + + // height/width are validated on host (DeformConvValidateAndParse) so int is safe here. + T val = BilinearInterpolate(input_ptr, + static_cast(height), + static_cast(width), + h_im, + w_im); + + // Match CPU path: always interpolate then apply mask to keep branch-free hot loop. + data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; + }; + + if constexpr (is_fixed) { +#pragma unroll + for (int i = 0; i < kH; ++i) { +#pragma unroll + for (int j = 0; j < kW; ++j) { + process_kernel_point(i, j); + } + } + } else { + for (int64_t i = 0; i < weight_h; ++i) { + for (int64_t j = 0; j < weight_w; ++j) { + process_kernel_point(i, j); + } + } + } + } +} + +// Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW. +template +__global__ void DeformConvAddBiasKernel( + T* Y, + const T* B, + DivMod spatial_div, // For dividing by (H * W) + DivMod channel_div, // For dividing by M (channel count) + int64_t total_elements) { + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += blockDim.x * gridDim.x) { + int64_t val = idx; + int64_t batch_channel_idx, pixel_idx; + + // 1. 
First decomposition: decompose idx into (batch_channel_idx, pixel_idx) + // Equivalent to: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); + spatial_div.divmod(val, batch_channel_idx, pixel_idx); + + int64_t batch_idx, channel_idx; + + // 2. Second decomposition: decompose batch_channel_idx into (batch_idx, channel_idx) + // Equivalent to: channel_idx = batch_channel_idx % M; + // We only need channel_idx (i.e. m) + channel_div.divmod(batch_channel_idx, batch_idx, channel_idx); + (void)batch_idx; // Only channel_idx is needed + + // channel_idx is what we need (i.e. m) + Y[idx] += DeformConvLdg(B + channel_idx); + } +} + +// Copy GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) into NCHW Y_g. +// src(c, j) with j = b_idx*output_image_size + pos -> dst[b_idx*M*output_image_size + c*output_image_size + pos]. +template +__global__ void CopyGemmOutputRowMajorToNCHWKernel( + const T* __restrict__ src, + T* __restrict__ dst, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { + int64_t pos = idx % output_image_size; + int64_t c = (idx / output_image_size) % M_per_group; + int64_t b_idx = idx / (output_image_size * M_per_group); + int64_t j = b_idx * output_image_size + pos; + // src index for row-major: c * (cur_parallel * output_image_size) + j + dst[b_idx * M * output_image_size + c * output_image_size + pos] = src[c * (cur_parallel * output_image_size) + j]; + } +} + +} // namespace + +template +Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { + int64_t total = N * M * out_h * out_w; + if (total <= 0) return Status::OK(); + + // 1. Prepare divisor + int64_t out_size = out_h * out_w; + + // 2. 
Create FastDivMod object (note: ensure int64_t version of DivMod is used here) + DivMod spatial_div(out_size); + DivMod channel_div(M); + + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + + // 3. Pass DivMod objects + DeformConvAddBiasKernel<<>>( + Y, + B, + spatial_div, + channel_div, + total); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + if (total <= 0) return Status::OK(); + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + CopyGemmOutputRowMajorToNCHWKernel<<>>( + gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, + const T* offset, + const T* mask, + T* col_buffer, + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask) { + const int64_t num_kernels = static_cast(C) * out_h * out_w * parallel_imgs; + if (num_kernels <= 0) { + return Status::OK(); + } + + const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; + const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) || + (col_numel > static_cast(std::numeric_limits::max())); + + int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); + + auto launch = [&](auto kH_tag, auto kW_tag) { + constexpr int KH = decltype(kH_tag)::value; + constexpr int KW = decltype(kW_tag)::value; + if (use_64bit) { + DeformableIm2ColKernel<<>>( + num_kernels, input, offset, 
mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(out_h), DivMod(out_w), DivMod(parallel_imgs), + DivMod(C / offset_group), use_mask, col_buffer); + } else { + DeformableIm2ColKernel<<>>( + static_cast(num_kernels), input, offset, mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(static_cast(out_h)), + DivMod(static_cast(out_w)), + DivMod(static_cast(parallel_imgs)), + DivMod(static_cast(C / offset_group)), + use_mask, col_buffer); + } + }; + + if (kH == 1 && kW == 1) { + launch(DeformConvKSize<1>{}, DeformConvKSize<1>{}); + } else if (kH == 3 && kW == 3) { + launch(DeformConvKSize<3>{}, DeformConvKSize<3>{}); + } else if (kH == 5 && kW == 5) { + launch(DeformConvKSize<5>{}, DeformConvKSize<5>{}); + } else { + launch(DeformConvKSize<-1>{}, DeformConvKSize<-1>{}); + } + return CUDA_CALL(cudaGetLastError()); +} + +#define INST_DeformConvIm2ColImpl(T) \ + template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool) + +INST_DeformConvIm2ColImpl(float); +INST_DeformConvIm2ColImpl(double); +INST_DeformConvIm2ColImpl(half); +INST_DeformConvIm2ColImpl(BFloat16); + +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); + +template Status DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); +template 
Status DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); + +// Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. +#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ + template <> \ + Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ + int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ + int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ + int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ + int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ + return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + reinterpret_cast(offset), \ + mask ? 
reinterpret_cast(mask) : nullptr, \ + reinterpret_cast(col_buffer), \ + parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ + offset_group, use_mask); \ + } \ + template <> \ + Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + const ORT_T* gemm_output, ORT_T* Y_g, \ + int64_t M, int64_t M_per_group, \ + int64_t output_image_size, int64_t cur_parallel) { \ + return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + reinterpret_cast(gemm_output), \ + reinterpret_cast(Y_g), \ + M, M_per_group, output_image_size, cur_parallel); \ + } \ + template <> \ + Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T * Y, const ORT_T* B, \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + reinterpret_cast(B), N, M, out_h, out_w); \ + } + +// BFloat16 is not delegated: ORT's BFloat16 is the same type used in device code (ToCudaType in +// cuda_common.h), so the explicit instantiations above (INST_DeformConvIm2ColImpl(BFloat16), etc.) suffice. +DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h new file mode 100644 index 0000000000000..0c26cb55311bc --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/status.h" + +namespace onnxruntime { +namespace cuda { + +// Adds bias to output: Y[n,m,oh,ow] += B[m]. Y is [N, M, out_h, out_w], B is [M]. +// T may be float, double, MLFloat16 (FP16), or BFloat16. 
+template +Status DeformConvAddBiasImpl( + cudaStream_t stream, + T* Y, + const T* B, + int64_t N, + int64_t M, + int64_t out_h, + int64_t out_w); + +// Copies GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) to NCHW slice at Y_g. +// T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel); + +// Fills col_buffer with deformable im2col. col_buffer layout: row-major [C*kH*kW, parallel_imgs*out_h*out_w]. +// Called once per batch block; caller does GEMM and bias. T may be float, double, MLFloat16 (FP16), or BFloat16. +template +Status DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, // [parallel_imgs, C, H, W] + const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] + const T* mask, // [parallel_imgs, offset_group*kH*kW, out_h, out_w] or nullptr + T* col_buffer, // [C*kH*kW, parallel_imgs*out_h*out_w] + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc index 71fb066c2898f..5d876ae5a2cc9 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc @@ -7,11 +7,37 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ RoiAlign, \ kOnnxDomain, \ 10, \ + 15, \ + T, \ + 
kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 16, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_TYPED_ROIALIGN_OP_22(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -67,13 +93,22 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ +#define SPECIALIZED_COMPUTE(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ADD_TYPED_ROIALIGN_OP_22(T) \ template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double) -// SPECIALIZED_COMPUTE(MLFloat16) +// MLFloat16 is available for RoiAlign op from version 16 (not version 10): +ADD_VERSIONED_TYPED_ROIALIGN_OP_16(MLFloat16) +ADD_TYPED_ROIALIGN_OP_22(MLFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; + +// BFloat16 is available for RoiAlign op from version 22: +ADD_TYPED_ROIALIGN_OP_22(BFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; } // namespace cuda }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu index 7acfd9d075461..87f4aba8e45b2 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu @@ -17,64 +17,72 @@ 
#include "roialign_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/accumulation_type.h" namespace onnxruntime { namespace cuda { template -__device__ T bilinear_interpolate( +__device__ AccumulationType_t bilinear_interpolate( const T* bottom_data, const int height, const int width, - T y, - T x, + AccumulationType_t y, + AccumulationType_t x, const bool is_mode_avg, const int index /* index for debug only*/) { + using TAcc = AccumulationType_t; + // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { + if (y < static_cast(-1.0f) || y > static_cast(height) || + x < static_cast(-1.0f) || x > static_cast(width)) { // empty - return 0; + return static_cast(0.0f); } - if (y <= 0) { - y = 0; + if (y <= static_cast(0.0f)) { + y = static_cast(0.0f); } - if (x <= 0) { - x = 0; + if (x <= static_cast(0.0f)) { + x = static_cast(0.0f); } - int y_low = (int)y; - int x_low = (int)x; + int y_low = static_cast(y); + int x_low = static_cast(x); int y_high; int x_high; if (y_low >= height - 1) { y_high = y_low = height - 1; - y = (T)y_low; + y = static_cast(y_low); } else { y_high = y_low + 1; } if (x_low >= width - 1) { x_high = x_low = width - 1; - x = (T)x_low; + x = static_cast(x_low); } else { x_high = x_low + 1; } - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; + TAcc ly = y - static_cast(y_low); + TAcc lx = x - static_cast(x_low); + TAcc hy = static_cast(1.0f) - ly; + TAcc hx = static_cast(1.0f) - lx; // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + TAcc v1 = static_cast(bottom_data[y_low * width + x_low]); + TAcc v2 = static_cast(bottom_data[y_low * width + x_high]); + TAcc v3 = static_cast(bottom_data[y_high * width + x_low]); + TAcc v4 = static_cast(bottom_data[y_high * width + x_high]); + TAcc w1 = hy * hx; + TAcc w2 = hy * lx; + TAcc w3 = ly * hx; + TAcc w4 = ly * lx; - T val = is_mode_avg - ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg - : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max + TAcc val = is_mode_avg + ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg + : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max return val; } @@ -97,6 +105,8 @@ __global__ void RoIAlignForward( const bool half_pixel, const int64_t* batch_indices_ptr, const int64_t batch_size) { + using TAcc = AccumulationType_t; + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -111,26 +121,27 @@ __global__ void RoIAlignForward( // If the index is out of range, we set the output to 0 for this RoI element. if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range"); - top_data[index] = 0; + top_data[index] = static_cast(0.0f); continue; } // Do not using rounding; this implementation detail is critical - T roi_offset = half_pixel ? 
T(0.5) : T(0); - T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset; - T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset; - T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset; - T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; + const TAcc spatial_scale_acc = static_cast(spatial_scale); + const TAcc roi_offset = half_pixel ? static_cast(0.5f) : static_cast(0.0f); + TAcc roi_start_w = static_cast(offset_bottom_rois[0]) * spatial_scale_acc - roi_offset; + TAcc roi_start_h = static_cast(offset_bottom_rois[1]) * spatial_scale_acc - roi_offset; + TAcc roi_end_w = static_cast(offset_bottom_rois[2]) * spatial_scale_acc - roi_offset; + TAcc roi_end_h = static_cast(offset_bottom_rois[3]) * spatial_scale_acc - roi_offset; + + TAcc roi_width = roi_end_w - roi_start_w; + TAcc roi_height = roi_end_h - roi_start_h; if (!half_pixel) { // backward compatibility // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); + roi_width = max(roi_width, static_cast(1.0f)); + roi_height = max(roi_height, static_cast(1.0f)); } - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + const TAcc bin_size_h = roi_height / static_cast(pooled_height); + const TAcc bin_size_w = roi_width / static_cast(pooled_width); const T* offset_bottom_data = bottom_data + static_cast((roi_batch_ind * channels + c) * height * width); @@ -138,26 +149,27 @@ __global__ void RoIAlignForward( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : _Ceil(roi_height / pooled_height); // e.g., = 2 + : static_cast(_Ceil(roi_height / static_cast(pooled_height))); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : _Ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? sampling_ratio : static_cast(_Ceil(roi_width / static_cast(pooled_width))); // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + const int grid_count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + const TAcc count = static_cast(grid_count); // e.g. = 4 - T output_val = 0.; + TAcc output_val = static_cast(0.0f); bool max_flag = false; for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + const TAcc y = roi_start_h + static_cast(ph) * bin_size_h + + (static_cast(iy) + static_cast(0.5f)) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); + const TAcc x = roi_start_w + static_cast(pw) * bin_size_w + + (static_cast(ix) + static_cast(0.5f)) * bin_size_w / + static_cast(roi_bin_grid_w); - T val = bilinear_interpolate( + const TAcc val = bilinear_interpolate( offset_bottom_data, height, width, y, x, is_mode_avg, index); if (is_mode_avg) { @@ -176,7 +188,7 @@ __global__ void RoIAlignForward( output_val /= count; } - top_data[index] = output_val; + top_data[index] = static_cast(output_val); } } @@ -241,6 +253,8 @@ void RoiAlignImpl( SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 9b23209953081..3dd50c1c03cbf 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -40,10 +40,70 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", 
DataTypeImpl::GetTensorType()), \ Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 18, 18, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 19, 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 21, 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 23, 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 24, 24, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Pad, \ kOnnxDomain, \ - 18, \ + 25, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -154,6 +214,11 @@ Status Pad::ComputeInternal(OpKernelContext* 
ctx) const { effective_input_extents.push_back(extent); } + TArray input_offsets(dimension_count); + for (int32_t i = 0; i < dimension_count; ++i) { + input_offsets[i] = -(*p_slices)[i]; + } + TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); @@ -236,7 +301,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } - if (IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) { + if (mode_ != Mode::Wrap && + IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) { // If we have entered here, it means the input can only be 4-D (NCHW), 3-D (CHW), or 2-D (HW) // NCHW input @@ -282,6 +348,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { input_dims, input_strides, lower_pads, + TArray(effective_input_extents), + input_offsets, value, static_cast(mode_), reinterpret_cast::MappedType*>(input_tensor.Data()), diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu index 6f530e800fdf2..6020769bf0ddf 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu @@ -7,19 +7,27 @@ namespace onnxruntime { namespace cuda { -// PadMode enum from core/providers/cpu/tensor/pad.h, cannot use that header because of nvcc/onnxruntime incompatibility +// PadMode enum from core/providers/cpu/tensor/padbase.h, cannot use that header because of nvcc/onnxruntime incompatibility enum class PadMode : int { Constant = 0, Reflect, - Edge + Edge, + Wrap }; +__device__ __forceinline__ int64_t WrapCoordinate(int64_t coord, int64_t extent) { + int64_t wrapped = coord % extent; + return wrapped < 0 ? 
wrapped + extent : wrapped; +} + template __global__ void _PadKernel( const size_t shape_rank, const TArray input_dims, const TArray input_strides, const TArray lower_pads, + const TArray effective_input_extents, + const TArray input_offsets, const T pad_value, const T* input_data, const TArray fdm_output_strides, @@ -33,33 +41,44 @@ __global__ void _PadKernel( int out_coord, r; fdm_output_strides[dim].divmod(output_index, out_coord, r); output_index = r; - int in_coord = 0; - if (out_coord < lower_pads[dim]) { - switch ((PadMode)pad_mode) { - case PadMode::Constant: - use_pad_value = true; - break; - case PadMode::Edge: - in_coord = 0; - break; - case PadMode::Reflect: - in_coord = lower_pads[dim] - out_coord; - break; - } - } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { - switch ((PadMode)pad_mode) { - case PadMode::Constant: - use_pad_value = true; - break; - case PadMode::Edge: - in_coord = input_dims[dim] - 1; - break; - case PadMode::Reflect: - in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); - break; - } + int64_t in_coord = 0; + if constexpr (pad_mode == static_cast(PadMode::Wrap)) { + const int64_t effective_input_extent = effective_input_extents[dim]; + const int64_t pre_pad = lower_pads[dim] + input_offsets[dim]; + const int64_t relative_coord = static_cast(out_coord) - pre_pad; + in_coord = input_offsets[dim] + WrapCoordinate(relative_coord, effective_input_extent); } else { - in_coord = out_coord - lower_pads[dim]; + if (out_coord < lower_pads[dim]) { + switch ((PadMode)pad_mode) { + case PadMode::Constant: + use_pad_value = true; + break; + case PadMode::Edge: + in_coord = 0; + break; + case PadMode::Reflect: + in_coord = lower_pads[dim] - out_coord; + break; + case PadMode::Wrap: + break; + } + } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { + switch ((PadMode)pad_mode) { + case PadMode::Constant: + use_pad_value = true; + break; + case PadMode::Edge: + in_coord = input_dims[dim] - 1; + 
break; + case PadMode::Reflect: + in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); + break; + case PadMode::Wrap: + break; + } + } else { + in_coord = out_coord - lower_pads[dim]; + } } input_index += input_strides[dim] * in_coord; } @@ -136,6 +155,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& effective_input_extents, + const TArray& input_offsets, const T pad_value, const int pad_mode, const T* input_data, @@ -149,17 +170,22 @@ void PadImpl( switch (pad_mode) { case 0: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; case 1: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; case 2: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, + pad_value, input_data, fdm_output_strides, output_data, N); + break; + case 3: + _PadKernel<<>>( + shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets, pad_value, input_data, fdm_output_strides, output_data, N); break; } @@ -211,6 +237,8 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( template void PadImpl(cudaStream_t stream, const size_t shape_rank, \ const TArray& input_dims, const TArray& input_strides, \ const TArray& lower_pads, \ + const TArray& effective_input_extents, \ + const TArray& input_offsets, \ const T pad_value, \ const int pad_mode, \ const T* input_data, \ diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.h b/onnxruntime/core/providers/cuda/tensor/pad_impl.h index dc700ea2304e9..96f158dd187fc 100644 --- 
a/onnxruntime/core/providers/cuda/tensor/pad_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.h @@ -32,6 +32,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& effective_input_extents, + const TArray& input_offsets, const T pad_value, const int pad_mode, const T* input_data, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index cce90f3ef82be..e20cc9140916a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -268,7 +268,7 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { const fs::path path{main_graph.ModelPath()}; if (path.has_filename()) { - const auto model_name = path.filename().string(); + const auto model_name = PathToUTF8String(path.filename().native()); LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; // Ensure enough characters are hashed in case model names are too short diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 7baac6aa1f6d0..0b137c4674b00 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -166,7 +166,7 @@ Status CreateCtxNode(const GraphViewer& graph_viewer, } attr_ep_cache_context->set_s(engine_data_str); } else { - std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string(); + std::string engine_cache_filename = PathToUTF8String(std::filesystem::path(engine_cache_path).filename().native()); attr_ep_cache_context->set_s(engine_cache_filename); std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out); if (engine_cache_file.is_open()) { @@ 
-188,7 +188,7 @@ Status CreateCtxNode(const GraphViewer& graph_viewer, attr_onnx_filename->set_name(ONNX_MODEL_FILENAME); attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string()); + attr_onnx_filename->set_s(PathToUTF8String(std::filesystem::path(onnx_model_path).filename().native())); attr_sdk_version->set_name(SDK_VERSION); attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 8c8e1879a2c6b..1d17a72641a4e 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -185,7 +185,7 @@ void BackendManager::TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVi model_blob_str = std::move(ss).str(); } } else { // External blob - model_blob_str = shared_context_.GetBinPath().filename().string(); + model_blob_str = PathToUTF8String(shared_context_.GetBinPath().filename().native()); } auto status = ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 0e49c0f897bea..68aa9a157f4a2 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -269,7 +269,7 @@ Status CreateEPContextNodes(Model* model, } context_bin_path = context_bin_path + ToPathString("_qnn.bin"); - context_cache_name = std::filesystem::path(context_bin_path).filename().string(); + context_cache_name = PathToUTF8String(std::filesystem::path(context_bin_path).filename().native()); // If generate ctx.onnx with share_ep_context enabled, all ctx.onnx should point to the same ctx.bin if (share_ep_contexts) { diff --git 
a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc index fb76f2110cbc8..87340e5b3ebeb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc @@ -237,7 +237,7 @@ Serializer::Serializer(const ProfilingInfo& profiling_info, tracelogging_provider_ep_enabled_(tracelogging_provider_ep_enabled) { #ifdef QNN_SYSTEM_PROFILE_API_ENABLED std::filesystem::path output_fs_filepath(profiling_info.csv_output_filepath); - qnn_log_filename_ = output_fs_filepath.filename().string(); + qnn_log_filename_ = PathToUTF8String(output_fs_filepath.filename().native()); // Remove extension (assumed to be ".csv") then add "_qnn.log" size_t extension_start_idx = qnn_log_filename_.rfind("."); qnn_log_filename_ = qnn_log_filename_.substr(0, extension_start_idx); diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 2a54bfea86e91..091b110d8c746 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -120,7 +120,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, attr_2->set_s(compute_capability); attr_3->set_name(ONNX_MODEL_FILENAME); attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); + attr_3->set_s(PathToUTF8String(std::filesystem::path(onnx_model_path).filename().native())); attr_4->set_name(SOURCE); attr_4->set_type(onnx::AttributeProto_AttributeType_STRING); attr_4->set_s(kTensorrtExecutionProvider); diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 5277d64ad3611..38848e98509ba 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ 
b/onnxruntime/core/providers/webgpu/compute_context.h @@ -107,7 +107,11 @@ class ComputeContextBase { // Get the logger. // inline const logging::Logger& Logger() const { +#if defined(ORT_USE_EP_API_ADAPTERS) + return ep_.GetEpLogger(); +#else return *ep_.GetLogger(); +#endif } // diff --git a/onnxruntime/core/providers/webgpu/controlflow/if.cc b/onnxruntime/core/providers/webgpu/controlflow/if.cc index 233d1f760383f..29b5e66d5075a 100644 --- a/onnxruntime/core/providers/webgpu/controlflow/if.cc +++ b/onnxruntime/core/providers/webgpu/controlflow/if.cc @@ -3,6 +3,10 @@ #include "core/providers/webgpu/controlflow/if.h" +#if defined(ORT_USE_EP_API_ADAPTERS) +#include "core/framework/error_code_helper.h" +#endif + using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -68,10 +72,20 @@ ONNX_OPERATOR_KERNEL_EX(If, .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), If); +#if !defined(ORT_USE_EP_API_ADAPTERS) Status If::Compute(OpKernelContext* ctx) const { // call the base CPU version. 
return onnxruntime::If::Compute(ctx); } +#else +Status If::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + return ToStatusAndRelease(ep::Api().ep.CreateIfKernel(info, impl)); +} + +Status If::Compute(OpKernelContext* /*ctx*/) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "If operator should be handled by ORT core."); +} +#endif } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/controlflow/if.h b/onnxruntime/core/providers/webgpu/controlflow/if.h index 0755c5d33d7a3..0aa30e939eb1a 100644 --- a/onnxruntime/core/providers/webgpu/controlflow/if.h +++ b/onnxruntime/core/providers/webgpu/controlflow/if.h @@ -10,6 +10,8 @@ namespace onnxruntime { namespace webgpu { +#if !defined(ORT_USE_EP_API_ADAPTERS) + // Use the CPU implementation for the logic class If final : public onnxruntime::If { public: @@ -18,5 +20,16 @@ class If final : public onnxruntime::If { Status Compute(OpKernelContext* ctx) const override; }; +#else + +class If final : public OpKernel { + public: + If(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + Status Compute(OpKernelContext* ctx) const override; +}; +#endif + } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index 6d66a7308f1de..5f109bf73e3c5 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -7,38 +7,48 @@ namespace onnxruntime { namespace webgpu { -bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || - (dst_device.Type() == OrtDevice::GPU 
&& src_device.Type() == OrtDevice::GPU) || - (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); -} - -common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - size_t bytes = src.SizeInBytes(); +common::Status DataTransferImpl::CopyTensor(void const* src_data, + bool src_is_gpu, + void* dst_data, + bool dst_is_gpu, + size_t bytes) const { if (bytes > 0) { - void const* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; - - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { + if (dst_is_gpu) { + if (src_is_gpu) { // copy from GPU to GPU buffer_manager_.MemCpy(static_cast(const_cast(src_data)), - static_cast(dst_data), bytes); + static_cast(dst_data), + bytes); } else { // copy from CPU to GPU - buffer_manager_.Upload(const_cast(src_data), static_cast(dst_data), bytes); + buffer_manager_.Upload(const_cast(src_data), + static_cast(dst_data), + bytes); } - } else /* if (src_device.Type() == OrtDevice::GPU) */ { + } else { // copy from GPU to CPU - buffer_manager_.Download(static_cast(const_cast(src_data)), dst_data, bytes); + buffer_manager_.Download(static_cast(const_cast(src_data)), + dst_data, + bytes); } } return Status::OK(); } +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + return impl_.CopyTensor(src.DataRaw(), + src.Location().device.Type() == OrtDevice::GPU, + dst.MutableDataRaw(), + dst.Location().device.Type() == OrtDevice::GPU, + src.SizeInBytes()); +} + } // namespace webgpu } // 
namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h index 0adf380149acf..e6ce92a7ca7a6 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.h +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -3,6 +3,7 @@ #pragma once +#include "core/common/status.h" #include "core/framework/data_transfer.h" #include "core/framework/execution_provider.h" @@ -11,9 +12,25 @@ namespace webgpu { class BufferManager; +// Low-level data transfer implementation that operates on raw pointers. +// Used by both DataTransfer (IDataTransfer subclass) and the C API data transfer wrapper. +class DataTransferImpl { + public: + DataTransferImpl(const BufferManager& buffer_manager) : buffer_manager_{buffer_manager} {}; + + common::Status CopyTensor(void const* src_data, + bool src_is_gpu, + void* dst_data, + bool dst_is_gpu, + size_t bytes) const; + + private: + const BufferManager& buffer_manager_; +}; + class DataTransfer : public IDataTransfer { public: - DataTransfer(const BufferManager& buffer_manager) : buffer_manager_{buffer_manager} {}; + DataTransfer(const BufferManager& buffer_manager) : impl_{buffer_manager} {}; ~DataTransfer() {}; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; @@ -21,7 +38,7 @@ class DataTransfer : public IDataTransfer { common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; private: - const BufferManager& buffer_manager_; + DataTransferImpl impl_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/ep/api.cc b/onnxruntime/core/providers/webgpu/ep/api.cc new file mode 100644 index 0000000000000..9eeb3d71df89f --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/api.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include + +#include "core/providers/webgpu/ep/factory.h" + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +namespace onnxruntime { +namespace webgpu { +void CleanupWebGpuContexts(); +void CleanupKernelRegistries(); +} // namespace webgpu +} // namespace onnxruntime + +namespace google { +namespace protobuf { +void ShutdownProtobufLibrary(); +} // namespace protobuf +} // namespace google + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + + // Manual init for the C++ API + onnxruntime::ep::ApiInit(ort_api_base); + + if (max_factories < 1) { + return onnxruntime::ep::Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + // Initialize the global default logger + ::onnxruntime::ep::adapter::LoggingManager::CreateDefaultLogger(default_logger); + + // Factory could use registration_name or define its own EP name. 
+ std::unique_ptr factory = std::make_unique(); + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; + + EXCEPTION_TO_RETURNED_STATUS_END +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // STEP.1 - Release the factory + delete static_cast(factory); + + // STEP.2 - Clean up cached kernel registries + onnxruntime::webgpu::CleanupKernelRegistries(); + + // STEP.3 - Clean up WebGPU contexts + onnxruntime::webgpu::CleanupWebGpuContexts(); + + // STEP.4 - Destroy the global default logger wrapper + ::onnxruntime::ep::adapter::LoggingManager::DestroyDefaultLogger(); + + // STEP.5 - Shutdown protobuf library + google::protobuf::ShutdownProtobufLibrary(); + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +} // extern "C" diff --git a/onnxruntime/core/providers/webgpu/ep/ep.cc b/onnxruntime/core/providers/webgpu/ep/ep.cc new file mode 100644 index 0000000000000..6beb62b5cf074 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/ep.cc @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "ep.h" + +#include "factory.h" + +#include "core/framework/run_options.h" +#include "core/framework/kernel_registry.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" + +#include "ep/get_capability_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +using onnxruntime::ep::Api; + +// Constructor +Ep::Ep(std::unique_ptr impl, Factory& factory, const OrtLogger& logger, const Config& config) + : onnxruntime::ep::adapter::Ep{std::move(impl), config.cpu_allocator, config.device_allocator}, + factory_{factory}, + logger_{logger}, + config_{config} { + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + GetKernelRegistry = GetKernelRegistryImpl; + Compile = nullptr; // Per-kernel EP does not use Compile + ReleaseNodeComputeInfos = nullptr; + GetPreferredDataLayout = GetPreferredDataLayoutImpl; + ShouldConvertDataLayoutForOp = ShouldConvertDataLayoutForOpImpl; + SetDynamicOptions = nullptr; // Not implemented + OnRunStart = OnRunStartImpl; + OnRunEnd = OnRunEndImpl; + CreateAllocator = CreateAllocatorImpl; + CreateSyncStreamForDevice = nullptr; // Not stream aware + GetCompiledModelCompatibilityInfo = nullptr; // Not a compiled EP + IsConcurrentRunSupported = IsConcurrentRunSupportedImpl; +} + +// OrtEp interface implementations +const char* ORT_API_CALL Ep::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->factory_.GetName(&ep->factory_); +} + +OrtStatus* ORT_API_CALL Ep::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + + auto& ep = *static_cast(static_cast(this_ptr)->EpImpl()); + Ort::ConstGraph ort_graph{graph}; + + // Get all nodes in the graph + std::vector all_nodes = 
ort_graph.GetNodes(); + + if (all_nodes.empty()) { + return nullptr; // No nodes to process + } + + std::vector candidate_nodes; + std::vector tentative_candidate_nodes; + + // For each node, check if we have a registered kernel for it + for (const auto& node : all_nodes) { + std::string ep_name = node.GetEpName(); + + if (ep_name == kWebGpuExecutionProvider) { + candidate_nodes.push_back(node); + continue; + } + + // Reject nodes already assigned to a different (non-CPU) EP + if (!ep_name.empty() && ep_name != kCpuExecutionProvider) { + continue; + } + + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(Api().ep.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def)); + + if (kernel_def == nullptr) { + LOGS(ep.GetEpLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.GetOperatorType() << " node name: " << node.GetName(); + continue; + } + + auto cpu_node_names = ep.GetForceCpuNodeNames(); + if (std::find(cpu_node_names.begin(), + cpu_node_names.end(), + node.GetName()) != cpu_node_names.end()) { + LOGS(ep.GetEpLogger(), INFO) << "Force CPU execution for node: " << node.GetName(); + continue; + } + + // + // The following code checks if the node is really supported by webgpu EP + // + +#define FALLBACK_TO_CPU_IF_EXIST_INPUT(idx) \ + if (inputs.size() > idx && inputs[idx] != nullptr) { \ + continue; \ + } + +#define FALLBACK_TO_CPU_IF_EXIST_OUTPUT(idx) \ + if (outputs.size() > idx && outputs[idx] != nullptr) { \ + continue; \ + } + + // Check for Attention + if (node.GetOperatorType() == "Attention" && node.GetDomain() == kMSDomain) { + const auto& inputs = node.GetInputs(); + const auto& outputs = node.GetOutputs(); + + // Current implementation does not support mask_index(input[3]), past(input[4]) and past_seq_len(input[6]) + FALLBACK_TO_CPU_IF_EXIST_INPUT(3); + FALLBACK_TO_CPU_IF_EXIST_INPUT(4); + FALLBACK_TO_CPU_IF_EXIST_INPUT(6); + + // Current implementation does not support present(output[1]) + 
FALLBACK_TO_CPU_IF_EXIST_OUTPUT(1); + + // If attribute past_present_share_buffer is set, fallback to CPU + bool has_past_present_share_buffer = false; + for (const auto& attr : node.GetAttributes()) { + if (attr.GetName() == "past_present_share_buffer") { + int64_t val = 0; + RETURN_IF_ERROR(attr.GetValue(val)); + if (val != 0) { + has_past_present_share_buffer = true; + } + break; + } + } + if (has_past_present_share_buffer) { + continue; + } + } + + candidate_nodes.push_back(node); + tentative_candidate_nodes.push_back(node); + } + + std::unordered_set cpu_preferred_nodes; + RETURN_IF_ERROR(onnxruntime::ep::GetCpuPreferredNodes(*ort_graph, + *graph_support_info, + static_cast(this_ptr)->GetOrtLogger(), + tentative_candidate_nodes, + cpu_preferred_nodes)); + + for (const auto& node : candidate_nodes) { + if (cpu_preferred_nodes.count(node) == 0) { + RETURN_IF_ERROR(Api().ep.EpGraphSupportInfo_AddSingleNode(graph_support_info, node)); + } + } + + return nullptr; + + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + + *kernel_registry = nullptr; + + // For the WebGPU EP, delegate to the CreateKernelRegistry function + // which properly constructs a registry using only public APIs + auto* ep = static_cast(this_ptr); + + auto& webgpu_ep = *static_cast(ep->EpImpl()); + + *kernel_registry = *webgpu_ep.GetKernelRegistryImpl(); + return nullptr; + + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::GetPreferredDataLayoutImpl(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // Delegate to the underlying WebGPU EP's GetPreferredLayout() + // DataLayout enum values map 1:1 to OrtEpDataLayout (NCHW=0, NHWC=1) + auto* ep = static_cast(this_ptr); + *preferred_data_layout = 
static_cast(ep->EpImpl()->GetPreferredLayout()); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // DataLayout enum values map 1:1 to OrtEpDataLayout (NCHW=0, NHWC=1) + auto* ep = static_cast(this_ptr); + auto result = ep->EpImpl()->ShouldConvertDataLayoutForOp(domain, op_type, + static_cast(target_data_layout)); + if (result.has_value()) { + *should_convert = result.value() ? 1 : 0; + } else { + *should_convert = -1; + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::OnRunStartImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + onnxruntime::RunOptions options{}; + // currently only option "gpu_graph_id" is used + auto graph_annotation_str = Api().ort.GetRunConfigEntry(run_options, kOrtRunOptionsConfigCudaGraphAnnotation); + if (graph_annotation_str != nullptr) { + auto status = options.config_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, graph_annotation_str); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + } + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->OnRunStart(options); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::OnRunEndImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* /*run_options*/, + _In_ bool sync_stream) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->OnRunEnd(sync_stream, {}); + if (!status.IsOK()) { + return 
Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Ep::IsConcurrentRunSupportedImpl(_In_ OrtEp* /*this_ptr*/, _Out_ bool* is_concurrent_run_supported) noexcept { + *is_concurrent_run_supported = false; + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto* ep = static_cast(this_ptr); + Ort::ConstMemoryInfo ort_memory_info{memory_info}; + if (ort_memory_info.GetAllocatorType() == OrtReadOnlyAllocator) { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, ep->config_.initializer_allocator); + } else { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, ep->config_.device_allocator); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/ep.h b/onnxruntime/core/providers/webgpu/ep/ep.h new file mode 100644 index 0000000000000..815623025f8a4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/ep.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/providers/webgpu/webgpu_execution_provider.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +class Factory; + +/// +/// A bridge class between the EP API and the WebGPU EP implementation. 
+/// +class Ep : public onnxruntime::ep::adapter::Ep { + public: + struct Config { + AllocatorPtr cpu_allocator; + AllocatorPtr device_allocator; + AllocatorPtr initializer_allocator; + }; + + Ep(std::unique_ptr impl, Factory& factory, const OrtLogger& logger, const Config& config); + + inline const OrtLogger& GetOrtLogger() const noexcept { + return logger_; + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept; + + static OrtStatus* ORT_API_CALL GetPreferredDataLayoutImpl(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout) noexcept; + + static OrtStatus* ORT_API_CALL ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert) noexcept; + + static OrtStatus* ORT_API_CALL OnRunStartImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options) noexcept; + + static OrtStatus* ORT_API_CALL OnRunEndImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream) noexcept; + + static OrtStatus* ORT_API_CALL IsConcurrentRunSupportedImpl(_In_ OrtEp* this_ptr, + _Out_ bool* is_concurrent_run_supported) noexcept; + + Factory& factory_; + const OrtLogger& logger_; + Config config_{}; +}; + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/factory.cc b/onnxruntime/core/providers/webgpu/ep/factory.cc new 
file mode 100644 index 0000000000000..99dd0c68f6954 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/factory.cc @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "factory.h" +#include "ep.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" + +#include "core/framework/execution_provider.h" +#include "core/framework/config_options.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/allocator.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +using onnxruntime::ep::Api; + +// Constructor +Factory::Factory() : OrtEpFactory{}, + default_memory_info_{WEBGPU_BUFFER, OrtMemoryInfoDeviceType_GPU, + 0, // vendor id + 0, // device id + OrtDeviceMemoryType_DEFAULT, + 0, // alignment + OrtDeviceAllocator}, + readonly_memory_info_{WEBGPU_BUFFER, OrtMemoryInfoDeviceType_GPU, + 0, // vendor id + 0, // device id + OrtDeviceMemoryType_DEFAULT, + 0, // alignment + OrtReadOnlyAllocator} { + ort_version_supported = ORT_API_VERSION; + + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; +} + +// Static C API implementations + +const char* ORT_API_CALL Factory::GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return kWebGpuExecutionProvider; +} + +const char* ORT_API_CALL Factory::GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return "Microsoft"; +} + +uint32_t ORT_API_CALL Factory::GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return 0; +} + +const char* ORT_API_CALL 
Factory::GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return "0.1.0"; +} + +OrtStatus* ORT_API_CALL Factory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto factory = static_cast(this_ptr); + + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (Api().ort.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + OrtEpDevice* ep_device = nullptr; + ORT_API_RETURN_IF_ERROR(Api().ep.CreateEpDevice(this_ptr, + &device, nullptr, nullptr, + &ep_device)); + ORT_API_RETURN_IF_ERROR(Api().ep.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_)); + ORT_API_RETURN_IF_ERROR(Api().ep.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_)); + ep_devices[num_ep_devices++] = ep_device; + } + } + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +OrtStatus* ORT_API_CALL Factory::CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + if (num_devices == 0) { + return Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, "No hardware devices provided to create WebGPU EP."); + } + + OrtKeyValuePairs* session_config_entries = nullptr; + ORT_API_RETURN_IF_ERROR(Api().ort.GetSessionOptionsConfigEntries(session_options, &session_config_entries)); + Ort::KeyValuePairs session_config_entries_holder(session_config_entries); // allow automatic release + + auto config_options = ConfigOptions{}; + 
const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + Api().ort.GetKeyValuePairs(session_config_entries, &keys, &values, &num_entries); + for (size_t i = 0; i < num_entries; ++i) { + auto status = config_options.AddConfigEntry(keys[i], values[i]); + if (!status.IsOK()) { + return Api().ort.CreateStatus((OrtErrorCode)status.Code(), status.ErrorMessage().c_str()); + } + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(config_options); + auto webgpu_ep = webgpu_ep_factory->CreateProvider(*session_options, *logger); + static_cast(webgpu_ep.get())->SetEpLogger(logger); + auto factory = static_cast(this_ptr); + const int context_id = webgpu_ep->GetDeviceId(); + Ep::Config webgpu_ep_config{ + CPUAllocator::DefaultInstance(), // CPU allocator + std::make_shared(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator + std::make_shared(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator + }; + *ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +void ORT_API_CALL Factory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + delete static_cast(ep); +} + +OrtStatus* ORT_API_CALL Factory::CreateAllocatorImpl( + OrtEpFactory* /*this_ptr*/, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Ort::ConstMemoryInfo ort_memory_info{memory_info}; + + if (ort_memory_info.GetAllocatorType() != OrtDeviceAllocator || + ort_memory_info.GetDeviceId() != 0 || + ort_memory_info.GetAllocatorName() != WEBGPU_BUFFER) { + return Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, + "Unsupported memory info for shared allocator."); + } + + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, + [](const OrtMemoryInfo&) -> 
AllocatorPtr { + return std::make_shared(WebGpuContextFactory::DefaultContext().BufferManager(), false); + }); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +void ORT_API_CALL Factory::ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* allocator) noexcept { + onnxruntime::ep::adapter::Allocator* ptr = static_cast(allocator); + delete ptr; +} + +OrtStatus* ORT_API_CALL Factory::CreateDataTransferImpl( + OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + *data_transfer = OrtWebGpuCreateDataTransfer(); // TODO(fs-eire): pass context id if needed + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +bool ORT_API_CALL Factory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; // Default: not stream aware +} + +OrtStatus* ORT_API_CALL Factory::CreateSyncStreamForDeviceImpl( + OrtEpFactory* /*this_ptr*/, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + *stream = nullptr; + return Api().ort.CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + EXCEPTION_TO_RETURNED_STATUS_END +} + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/factory.h b/onnxruntime/core/providers/webgpu/ep/factory.h new file mode 100644 index 0000000000000..f23b3871ebc60 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/factory.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "ep.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +/// +/// A bridge class between the EP API and the WebGPU EP Factory implementation. 
+/// +class Factory : public OrtEpFactory { + private: + // Static C API implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + Ort::MemoryInfo default_memory_info_; + Ort::MemoryInfo readonly_memory_info_; // used for initializers + + public: + Factory(); + ~Factory() = default; +}; + +} // namespace ep +} // namespace webgpu 
+} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc index 3fa062f327ba2..a84660a020fed 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.cc +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -79,36 +79,40 @@ template class Range; template class Range; template class Range; -void RegisterRangeKernels(KernelRegistry& kernel_registry, bool enable_int64) { - // Helper lambda to create kernel - auto create_range_kernel_info = [](auto type_tag) { - using T = decltype(type_tag); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { - out = std::make_unique>(info); - return Status::OK(); - }; - - return KernelCreateInfo( - KernelDefBuilder() - .SetName("Range") - .SetDomain(kOnnxDomain) - .SinceVersion(11) - .Provider(kWebGpuExecutionProvider) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .InputMemoryType(OrtMemTypeCPU, 0) - .InputMemoryType(OrtMemTypeCPU, 1) - .InputMemoryType(OrtMemTypeCPU, 2) - .Build(), - kernel_create_fn); - }; +namespace { +template +Status CreateRangeKernel(FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) { + out = std::make_unique>(info); + return Status::OK(); +} + +template +KernelCreateInfo CreateRangeKernelInfo() { + return KernelCreateInfo( + KernelDefBuilder() + .SetName("Range") + .SetDomain(kOnnxDomain) + .SinceVersion(11) + .Provider(kWebGpuExecutionProvider) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 0) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .Build(), + CreateRangeKernel); +} + +} // namespace + +void RegisterRangeKernels(KernelRegistry& kernel_registry, bool enable_int64) { // Always register float and int32_t - ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(float{}))); - 
ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(int32_t{}))); + ORT_THROW_IF_ERROR(kernel_registry.Register(CreateRangeKernelInfo())); + ORT_THROW_IF_ERROR(kernel_registry.Register(CreateRangeKernelInfo())); // Register int64_t only if int64 support is enabled if (enable_int64) { - ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(int64_t{}))); + ORT_THROW_IF_ERROR(kernel_registry.Register(CreateRangeKernelInfo())); } } diff --git a/onnxruntime/core/providers/webgpu/math/einsum.cc b/onnxruntime/core/providers/webgpu/math/einsum.cc index bce173b1c62de..e17c0281c738f 100644 --- a/onnxruntime/core/providers/webgpu/math/einsum.cc +++ b/onnxruntime/core/providers/webgpu/math/einsum.cc @@ -325,9 +325,12 @@ Status EinsumProgram::GenerateShaderCode(ShaderHelper& shader) const { // Generate a WGSL loop header for reduction over this dimension // Format like: for(var j: u32 = 0; j < uniforms.input0_shape[1]; j++) {, given equation // "ij,jk->ik". + std::string shape_access = GetElementAt( + "uniforms.input" + std::to_string(lhs_term_index) + "_shape", + input_index, + static_cast(inputs[lhs_term_index].get().Rank())); reduce_ops_loop_headers.push_back("for(var " + symbol + ": u32 = 0; " + symbol + " < " + - "uniforms.input" + std::to_string(lhs_term_index) + - "_shape[" + std::to_string(input_index) + "]; " + + shape_access + "; " + symbol + "++) {"); // Add corresponding loop closing brace diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 6bb0f688bfdb7..2695c2800d37a 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -90,7 +90,7 @@ template KernelCreateInfo CreateCastKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + 
KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 75453b991a0cd..c6178d44dba75 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -98,7 +98,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } Prepare prepare; - ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), input_tensors, prepare)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(&context.KernelContext(), input_tensors, prepare)); if (prepare.output_num_elements == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 0dacd589cbba8..0d39b1ec9d35e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -108,7 +108,7 @@ template KernelCreateInfo CreateExpandVersionedKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; @@ -129,7 +129,7 @@ template KernelCreateInfo CreateExpandKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git 
a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index b3e5c7b4e8310..970c2d6bed7a3 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -60,7 +60,7 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { Status Gather::ComputeInternal(ComputeContext& context) const { Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(&context.KernelContext(), p)); uint32_t data_size = onnxruntime::narrow(p.output_tensor->Shape().Size()); if (data_size == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc index 0e77ec46bbddb..7a576c4b53ecf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/pad.cc +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -49,7 +49,7 @@ Status Pad::ComputeInternal(ComputeContext& context) const { const auto pads_data = pads_tensor->DataAsSpan(); // Compute Pads by applying axes if specified otherwise copy the supplied pads. 
- PadBase::ComputePads(context.KernelContext(), data_rank, pads_data, pads); + PadBase::ComputePadsImpl(context.KernelContext(), data_rank, pads_data, pads); // Separate out any negative pads into the slices array PadBase::SeparateNegativeToSlices(pads, slices); diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc index b211d48dab1c9..09194aa9f4dbb 100644 --- a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -3,11 +3,73 @@ #include "core/providers/webgpu/webgpu_kernel.h" #include "core/providers/webgpu/webgpu_supported_types.h" -#include "core/providers/cpu/tensor/shape_op.h" namespace onnxruntime { namespace webgpu { +#ifndef SHARED_PROVIDER +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/framework/op_kernel.h" +#endif + +#include +#include + +class Shape final : public OpKernel { + public: + Shape(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } + } + + // Takes a tensor as input and outputs an 1D int64 tensor + // containing the shape of the input tensor. + Status Compute(OpKernelContext* context) const override { + const auto* input = context->Input(0); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? 
true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); + + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->MutableData(), onnxruntime::narrow(true_start), onnxruntime::narrow(slice_length)); + } + } + + return Status::OK(); + } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); +}; + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Shape, kOnnxDomain, diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 230d172d7404e..5cc09501ab378 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -244,6 +244,11 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { return Status::OK(); } + // 1D transpose is identity - just copy the GPU buffer. 
+ if (rank == 1) { + return Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor); + } + return DoTranspose(context, *p_perm, *input_tensor, *output_tensor); } diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc index 104fcf1812af8..3337448564bed 100644 --- a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -12,7 +12,7 @@ template KernelCreateInfo CreateUnsqueezeVersionedKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; @@ -47,7 +47,7 @@ template KernelCreateInfo CreateUnsqueezeKernelInfo(bool enable_int64) { const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git a/onnxruntime/core/providers/webgpu/tensor/upsample.cc b/onnxruntime/core/providers/webgpu/tensor/upsample.cc index fb406883ba4ba..8f51ed45004bf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/webgpu/tensor/upsample.cc @@ -90,7 +90,7 @@ Status Upsample::ComputeInternal(ComputeContext& context) const { InlinedVector scales_array(input_dims.size()); // opset < 10 - if (OpKernel::Node().InputDefs().size() == 1) { + if (OpKernel::Node().SinceVersion() < 10) { scales_array = scales_; // Compute output shape from scales attributes and input dims 
ComputeOutputShape(scales_array, input_dims, output_dims); diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index 3560fba522cb8..428fe863ab61b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -82,7 +82,7 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { -> void { const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; - const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; + const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << u32(component_c" + x + " * 8)))"; shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4 + " + x) << ";\n" << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index c6255e6f352d9..b4d751ce3a2c0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -3,6 +3,7 @@ #include "core/providers/webgpu/webgpu_execution_provider.h" +#include #include #include #include @@ -445,8 +446,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterElements); -std::unique_ptr RegisterKernels(bool enable_graph_capture = false, bool enable_int64 = false) { - auto kernel_registry = std::make_unique(); +std::unique_ptr RegisterKernels(bool 
enable_graph_capture, bool enable_int64) { + auto kernel_registry = std::make_unique(); static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing @@ -837,6 +838,72 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals return kernel_registry; } +#if defined(ORT_USE_EP_API_ADAPTERS) + +namespace { +std::mutex g_kernel_registry_mutex; +std::shared_ptr g_kernel_registry; +std::shared_ptr g_graph_capture_kernel_registry; +std::shared_ptr g_int64_kernel_registry; +} // namespace + +void CleanupKernelRegistries() { + std::lock_guard lock(g_kernel_registry_mutex); + g_kernel_registry.reset(); + g_graph_capture_kernel_registry.reset(); + g_int64_kernel_registry.reset(); +} +#endif + +std::shared_ptr GetKernelRegistry(bool enable_graph_capture, bool enable_int64) { + // kernel registry variables are defined differently based on build configuration + // + // - When building as a static library, use static local variable. This is because + // we don't have a reliable way to explicitly destroy the kernel registry after + // use. + // + // - When building as a shared library, use global variables. The cleanup will be performed + // when `ReleaseEpFactory` is called. + // + // Graph capture mode needs a separate kernel registry because contrib kernel registration + // differs based on enable_graph_capture, and enable_int64 is always true when + // enable_graph_capture is true. 
+ if (enable_graph_capture) { +#if !defined(ORT_USE_EP_API_ADAPTERS) + static std::shared_ptr registry = RegisterKernels(true, true); + return registry; +#else + std::lock_guard lock(g_kernel_registry_mutex); + if (g_graph_capture_kernel_registry == nullptr) { + g_graph_capture_kernel_registry = RegisterKernels(true, true); + } + return g_graph_capture_kernel_registry; +#endif + } else if (enable_int64) { +#if defined(ORT_USE_EP_API_ADAPTERS) + std::lock_guard lock(g_kernel_registry_mutex); + if (g_int64_kernel_registry == nullptr) { + g_int64_kernel_registry = RegisterKernels(false, true); + } + return g_int64_kernel_registry; +#else + static std::shared_ptr registry = RegisterKernels(false, true); + return registry; +#endif + } else { +#if defined(ORT_USE_EP_API_ADAPTERS) + std::lock_guard lock(g_kernel_registry_mutex); + if (g_kernel_registry == nullptr) { + g_kernel_registry = RegisterKernels(false, false); + } + return g_kernel_registry; +#else + static std::shared_ptr registry = RegisterKernels(false, false); + return registry; +#endif + } +} + } // namespace webgpu using namespace webgpu; @@ -850,6 +917,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, preferred_data_layout_{config.data_layout}, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, enable_graph_capture_{config.enable_graph_capture}, + // enable_int64_ is always true when enable_graph_capture_ is true enable_int64_{config.enable_graph_capture || config.enable_int64}, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { @@ -882,6 +950,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { }; } +#if !defined(ORT_USE_EP_API_ADAPTERS) std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, @@ -973,20 +1042,7 @@ std::vector> WebGpuExecutionProvider::GetCapa return result; } 
-std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { - // Cache registries based on enable_graph_capture_ and enable_int64_ flags - // Note: enable_int64_ is always true when enable_graph_capture_ is true - if (enable_graph_capture_) { - static std::shared_ptr registry = webgpu::RegisterKernels(true, true); - return registry; - } else if (enable_int64_) { - static std::shared_ptr registry = webgpu::RegisterKernels(false, true); - return registry; - } else { - static std::shared_ptr registry = webgpu::RegisterKernels(false, false); - return registry; - } -} +#endif // !defined(ORT_USE_EP_API_ADAPTERS) std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { return std::make_unique(BufferManager()); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index b5a6b5f167faf..b46d3f3cb45d2 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -4,6 +4,11 @@ #pragma once +#include +#include +#include +#include + #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" #include "core/graph/constants.h" @@ -28,6 +33,9 @@ class GpuBufferAllocator; // Forward declare CapturedCommandInfo which is now defined in webgpu_context.h struct CapturedCommandInfo; + +// The actual implementation of kernel registration. 
+std::shared_ptr GetKernelRegistry(bool enable_graph_capture, bool enable_int64); } // namespace webgpu struct WebGpuExecutionProviderConfig { @@ -44,13 +52,21 @@ class WebGpuExecutionProvider : public IExecutionProvider { WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& config); ~WebGpuExecutionProvider() override; + inline auto GetKernelRegistryImpl() const { + return webgpu::GetKernelRegistry(enable_graph_capture_, enable_int64_); + } + +#if !defined(ORT_USE_EP_API_ADAPTERS) std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; - std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override { + return GetKernelRegistryImpl(); + } +#endif std::unique_ptr GetDataTransfer() const override; #if defined(__wasm__) std::unique_ptr GetExternalDataLoader() const override; @@ -83,8 +99,18 @@ class WebGpuExecutionProvider : public IExecutionProvider { Status ReplayGraph(int graph_annotation_id) override; webgpu::BufferManager& BufferManager() const; AllocatorPtr PrepackAllocator() const { return prepack_allocator_; } + std::span GetForceCpuNodeNames() const { return force_cpu_node_names_; } uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; } +#if defined(ORT_USE_EP_API_ADAPTERS) + inline onnxruntime::ep::adapter::Logger& GetEpLogger() const { + return *ep_logger_; + } + inline void SetEpLogger(const OrtLogger* logger) { + ep_logger_ = std::make_unique(logger); + } +#endif + private: bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); @@ -114,6 +140,10 @@ class WebGpuExecutionProvider : public IExecutionProvider { // Allocator for prepacked weights (uses buffers without mapping) AllocatorPtr prepack_allocator_; + +#if 
defined(ORT_USE_EP_API_ADAPTERS) + std::unique_ptr ep_logger_; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index fc2496f0c7b68..16899370e47f1 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -284,11 +284,11 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( // WebGPU DataTransfer implementation wrapper for the C API with lazy initialization struct WebGpuDataTransferImpl : OrtDataTransferImpl { - WebGpuDataTransferImpl(const OrtApi& ort_api_in) + WebGpuDataTransferImpl(const OrtApi& ort_api_in, int context_id) : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()}, data_transfer_{nullptr}, - context_id_{0}, // Always use context 0 for Environment's data transfer + context_id_{context_id}, init_mutex_{} { ort_version_supported = ORT_API_VERSION; CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback @@ -327,9 +327,9 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { // If both are GPU, they must have the same device ID if (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) { - uint64_t src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); - uint64_t dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); - if (src_device_id != dst_device_id) { + int src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); + int dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); + if (src_device_id != impl.context_id_ || dst_device_id != impl.context_id_) { return false; // Cannot copy between different devices } } @@ -362,19 +362,40 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { auto& context = WebGpuContextFactory::DefaultContext(); - // Create the DataTransfer instance - // Note: The DataTransfer holds a const reference to 
BufferManager. The BufferManager's lifecycle + // Create the DataTransferImpl instance + // Note: The DataTransferImpl holds a const reference to BufferManager. The BufferManager's lifecycle // is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists // for the lifetime of the application, ensuring the reference remains valid. - impl.data_transfer_ = std::make_unique(context.BufferManager()); + impl.data_transfer_ = std::make_unique(context.BufferManager()); } } // Now perform the actual tensor copy for (size_t idx = 0; idx < num_tensors; ++idx) { - const OrtValue* src_tensor = src_tensors[idx]; - OrtValue* dst_tensor = dst_tensors[idx]; - auto status = impl.data_transfer_->CopyTensor(src_tensor->Get(), *dst_tensor->GetMutable()); +#if defined(ORT_USE_EP_API_ADAPTERS) + Ort::ConstValue src_value{src_tensors[idx]}; + const void* src_data = src_value.GetTensorRawData(); + size_t size = src_value.GetTensorSizeInBytes(); + bool src_is_gpu = src_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU; + + Ort::UnownedValue dst_value{dst_tensors[idx]}; + void* dst_data = dst_value.GetTensorMutableRawData(); + bool dst_is_gpu = dst_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU; +#else + const Tensor& src_tensor = src_tensors[idx]->Get(); + const void* src_data = src_tensor.DataRaw(); + size_t size = src_tensor.SizeInBytes(); + bool src_is_gpu = src_tensor.Location().device.Type() == OrtDevice::GPU; + + Tensor& dst_tensor = *dst_tensors[idx]->GetMutable(); + void* dst_data = dst_tensor.MutableDataRaw(); + bool dst_is_gpu = dst_tensor.Location().device.Type() == OrtDevice::GPU; +#endif + auto status = impl.data_transfer_->CopyTensor(src_data, + src_is_gpu, + dst_data, + dst_is_gpu, + size); if (!status.IsOK()) { return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str()); } @@ -398,19 +419,23 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { const OrtApi& 
ort_api; const OrtEpApi& ep_api; - std::unique_ptr data_transfer_; // Lazy-initialized - int context_id_; // Track which context we're using - std::mutex init_mutex_; // Protects lazy initialization + std::unique_ptr data_transfer_; // Lazy-initialized + int context_id_; // Track which context we're using + std::mutex init_mutex_; // Protects lazy initialization }; -OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() { +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id /* = 0 */) { +#if defined(ORT_USE_EP_API_ADAPTERS) + return new WebGpuDataTransferImpl(onnxruntime::ep::Api().ort, context_id); +#else // Validate API version is supported const OrtApi* api = OrtApis::GetApi(ORT_API_VERSION); if (!api) { // API version not supported - return nullptr to indicate failure return nullptr; } - return new WebGpuDataTransferImpl(*api); + return new WebGpuDataTransferImpl(*api, context_id); +#endif } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 021e33ef25309..876a2e11d791a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -22,6 +22,6 @@ struct WebGpuProviderFactoryCreator { // C API to create data transfer for WebGPU EP with lazy initialization // Context will be determined from tensors during the first CopyTensors call // Caller takes ownership of the returned OrtDataTransferImpl* -OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(); +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id = 0); } // namespace onnxruntime diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 7cd02e5413407..f37c685cf2f28 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -618,7 +618,7 @@ Status 
Environment::CreateAndRegisterInternalEps() { Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) { std::lock_guard lock{mutex_}; - std::string lib_file_name = std::filesystem::path(lib_path).filename().string(); + std::string lib_file_name = PathToUTF8String(std::filesystem::path(lib_path).filename().native()); Env::Default().GetTelemetryProvider().LogRegisterEpLibraryWithLibPath(registration_name, lib_file_name); std::vector internal_factories = {}; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 08b58f3de1a11..b873c95b496bb 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2612,7 +2612,7 @@ common::Status InferenceSession::Initialize() { // and log telemetry std::filesystem::path model_path = graph.ModelPath(); - std::string model_file_name = model_path.filename().string(); + std::string model_file_name = PathToUTF8String(model_path.filename().native()); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); env.GetTelemetryProvider().LogSessionCreation( session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), @@ -4096,7 +4096,7 @@ void InferenceSession::LogAllSessions() { if (nullptr != model) { onnxruntime::Graph& graph = model->MainGraph(); std::filesystem::path model_path = graph.ModelPath(); - std::string model_file_name = model_path.filename().string(); + std::string model_file_name = PathToUTF8String(model_path.filename().native()); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); std::string model_weight_type = session->GetWeightDataType(); std::string model_graph_hash = session->GetGraphHash(); diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 00ea8947b9dd7..06ff495327ecd 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ 
b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -451,7 +451,7 @@ TEST(MatMul2Bits, Float32_2b_Accuracy4) { TestMatMul2BitsTyped(); } -#ifdef USE_WEBGPU +#if defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) namespace { @@ -594,7 +594,7 @@ TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_DP4A) { RunTest2Bits(opts); } -#endif // USE_WEBGPU +#endif // defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp new file mode 100644 index 0000000000000..f7e461c29843a --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include + +#include "mlas.h" +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +namespace { + +// Compare fused MLAS unary activation paths against unfused baselines for +// SiLU and exact GELU(erf). + +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kInvSqrt2 = 0.7071067811865475244f; +constexpr int64_t kFusedBytesPerElement = 2; +constexpr int64_t kSiluUnfusedBytesPerElement = 5; +constexpr int64_t kGeluUnfusedBytesPerElement = 7; + +struct DispatchedUnaryPathInfo { + int64_t bytes_per_element; + const char* label; +}; + +DispatchedUnaryPathInfo GetSiluDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic path, + // which materializes the logistic result before the final multiply. 
+ return {kSiluUnfusedBytesPerElement, "generic_fallback"}; +} + +DispatchedUnaryPathInfo GetGeluErfDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic exact + // GELU(erf) implementation, which uses separate scale, erf, and final passes. + return {kGeluUnfusedBytesPerElement, "generic_fallback"}; +} + +std::vector MakeInput(size_t n, float min_value, float max_value) { + auto data = RandomVectorUniform(n, min_value, max_value); + + if (!data.empty()) { + data[0] = 0.0f; + } + if (data.size() > 1) { + data[1] = -0.0f; + } + if (data.size() > 2) { + data[2] = -1.0f; + } + if (data.size() > 3) { + data[3] = 1.0f; + } + + return data; +} + +template +void RunDispatchedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + DispatchedUnaryPathInfo path_info) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + state.SetLabel(path_info.label); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + const int64_t bytes_per_iteration = static_cast(n) * static_cast(sizeof(float)) * path_info.bytes_per_element; + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * bytes_per_iteration); +} + +template +void RunUnfusedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + int64_t bytes_per_element) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { 
+ kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + const int64_t bytes_per_iteration = static_cast(n) * static_cast(sizeof(float)) * bytes_per_element; + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * bytes_per_iteration); +} + +static void UnaryKernelArgs(benchmark::internal::Benchmark* b) { + for (int n : {1, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 511, 512, 1024, 4096, 16384, 65536, 262144}) { + b->Arg(n); + } +} + +void BM_SiluDispatch(benchmark::State& state) { + // Fused MLAS SiLU entry point. On supported platforms this may dispatch to a + // specialized implementation that combines the activation into a single + // kernel instead of exposing intermediate results. + RunDispatchedUnaryBenchmark(state, MlasComputeSilu, kSiluMinValue, kSiluMaxValue, GetSiluDispatchPathInfo()); +} + +void BM_SiluUnfusedDispatch(benchmark::State& state) { + // Unfused SiLU baseline: compute logistic(x) first and then multiply by x in + // a separate elementwise pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeLogistic(input, output, n); + MlasEltwiseMul(input, output, output, n); + }, + kSiluMinValue, + kSiluMaxValue, + kSiluUnfusedBytesPerElement); +} + +void BM_GeluErfDispatchExact(benchmark::State& state) { + // Fused MLAS GELU(erf) entry point using the exact erf-based formulation. + // On AMD64 this goes through the platform dispatch layer and may select an + // architecture-specific implementation. 
+ RunDispatchedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeGeluErf(input, output, n); + }, + kGeluMinValue, + kGeluMaxValue, + GetGeluErfDispatchPathInfo()); +} + +void BM_GeluErfUnfusedExact(benchmark::State& state) { + // Unfused exact GELU(erf) baseline: scale by 1/sqrt(2), run erf, then apply the + // final 0.5 * x * (erf(x / sqrt(2)) + 1) transform in a separate pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + for (size_t i = 0; i < n; ++i) { + output[i] = input[i] * kInvSqrt2; + } + + MlasComputeErf(output, output, n); + + for (size_t i = 0; i < n; ++i) { + output[i] = 0.5f * input[i] * (output[i] + 1.0f); + } + }, + kGeluMinValue, + kGeluMaxValue, + kGeluUnfusedBytesPerElement); +} + +} // namespace + +BENCHMARK(BM_SiluDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_SiluUnfusedDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfDispatchExact)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfUnfusedExact)->Apply(UnaryKernelArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp index 14e05fd42538e..3762e30af352d 100644 --- a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp +++ b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp @@ -586,9 +586,11 @@ class MlasNeonFp16DequantB8BitTest : public MlasTestBase { // Reference dequantization for 8-bit packed data. // Uses explicit position-based indexing to match the packed layout exactly. + // Emulates the kernel's fp16 computation order: + // 1. neg_scaled_zp = fp16_round(-(scale * zp)) [once per block per column] + // 2. result = fp16_round(neg_scaled_zp + value * scale) [emulates fp16 fma] // // Packed layout for N>=8 group (8N-interleaved): - // For each K position, 8 consecutive bytes hold one value per column. 
// byte[groupStart + k * 8 + col] = value for K=k, column=col // // Packed layout for remainder N<8 (sequential): @@ -610,12 +612,14 @@ for (size_t col = 0; col < 8; ++col) { const size_t absCol = n + col; const size_t srcIdx = groupStart + k * 8 + col; - const size_t dstIdx = srcIdx; // output has the same interleaved layout + const size_t dstIdx = srcIdx; const float value = static_cast<float>(src[srcIdx]); const float scale = scales[absCol * BlkNum + block].ToFloat(); const float zp = static_cast<float>( UseZeroPoints ? zero_points[absCol * BlkNum + block] : 128); - dst[dstIdx] = MLAS_FP16(value * scale - zp * scale); + // Emulate kernel: neg_scaled_zp rounded to fp16, then fma + const float neg_szp = MLAS_FP16(-(scale * zp)).ToFloat(); + dst[dstIdx] = MLAS_FP16(neg_szp + value * scale); } } } @@ -631,7 +635,8 @@ const float scale = scales[n * BlkNum + block].ToFloat(); const float zp = static_cast<float>( UseZeroPoints ? zero_points[n * BlkNum + block] : 128); - dst[dstIdx] = MLAS_FP16(value * scale - zp * scale); + const float neg_szp = MLAS_FP16(-(scale * zp)).ToFloat(); + dst[dstIdx] = MLAS_FP16(neg_szp + value * scale); } } } diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp new file mode 100644 index 0000000000000..e87768ce3e660 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" + +#include +#include + +#if defined(MLAS_TARGET_AMD64) + +namespace { + +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; + +constexpr float kGeluAbsoluteTolerance = 2e-6f; +constexpr float kGeluRelativeTolerance = 2e-5f; +constexpr float kSiluAbsoluteTolerance = 3e-5f; +constexpr float kSiluRelativeTolerance = 5e-5f; + +constexpr std::array kShortTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255}; + +constexpr std::array kLongTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, + 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; + +bool IsGeluErfAvx512Dispatched() { + return GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F; +} + +bool IsSiluAvx512Dispatched() { + return GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F; +} + +bool UnaryOutputsMatch(float actual, float expected, float absolute_tolerance, float relative_tolerance, + bool check_signed_zero) { + if (std::isnan(expected)) { + return std::isnan(actual); + } + + if (std::isinf(expected)) { + return std::isinf(actual) && (std::signbit(actual) == std::signbit(expected)); + } + + if (check_signed_zero && actual == 0.0f && expected == 0.0f) { + return std::signbit(actual) == std::signbit(expected); + } + + const float diff = std::fabs(actual - expected); + if (diff <= absolute_tolerance) { + return true; + } + + const float scale = std::max(std::fabs(actual), std::fabs(expected)); + return scale > 0.0f && diff <= scale * relative_tolerance; +} + +const std::vector& GetGeluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 0.0f, + -0.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + }; 
+ + return values; +} + +const std::vector<float>& GetSiluSpecialValues() { + static const std::vector<float> values = { + std::numeric_limits<float>::quiet_NaN(), + std::numeric_limits<float>::infinity(), + -std::numeric_limits<float>::infinity(), + std::numeric_limits<float>::max(), + -std::numeric_limits<float>::max(), + 1.0e9f, + -1.0e9f, + 0.0f, + -0.0f, + -20.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + 20.0f, + }; + + return values; +} + +void FillInput(float* input, size_t n, float minimum_value, float maximum_value, + const std::vector<float>& special_values, uint32_t seed) { + std::mt19937 generator(seed); + std::uniform_real_distribution<float> distribution(minimum_value, maximum_value); + + for (size_t i = 0; i < n; ++i) { + input[i] = distribution(generator); + } + + const size_t special_count = std::min(n, special_values.size()); + for (size_t i = 0; i < special_count; ++i) { + input[i] = special_values[i]; + } +} + +class MlasComputeGeluErfAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer<float> input_buffer_; + MatrixGuardBuffer<float> generic_output_buffer_; + MatrixGuardBuffer<float> public_output_buffer_; + MatrixGuardBuffer<float> avx512_output_buffer_; + + void ExecuteCommon(const std::vector<size_t>& sizes, size_t iterations) { + if (!IsGeluErfAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F GELU(erf) dispatch is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kGeluMinValue, kGeluMaxValue, GetGeluSpecialValues(), + static_cast<uint32_t>(size * 131u + iteration * 977u + 17u)); + + MlasGeluErfKernel(input, generic_output, size); + MlasComputeGeluErf(input, public_output, size); + MlasGeluErfKernelAvx512F(input, avx512_output, size); + + 
for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public GELU(erf) mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "GELU(erf) mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public/API GELU(erf) dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Gelu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector<size_t>(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector<size_t>(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +class MlasComputeSiluAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer<float> input_buffer_; + MatrixGuardBuffer<float> generic_output_buffer_; + MatrixGuardBuffer<float> public_output_buffer_; + MatrixGuardBuffer<float> avx512_output_buffer_; + + void ExecuteCommon(const std::vector<size_t>& sizes, size_t iterations) { + if (!IsSiluAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F SiLU dispatch is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < 
iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kSiluMinValue, kSiluMaxValue, GetSiluSpecialValues(), + static_cast<uint32_t>(size * 149u + iteration * 991u + 31u)); + + MlasSiluKernel(input, generic_output, size); + MlasComputeSilu(input, public_output, size); + MlasSiluKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public/API Silu dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Silu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector<size_t>(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + 
ExecuteCommon(std::vector<size_t>(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +} // namespace + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests<MlasComputeGeluErfAvx512Test>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasComputeSiluAvx512Test>::RegisterShortExecute(); + } else { + count += MlasLongExecuteTests<MlasComputeGeluErfAvx512Test>::RegisterLongExecute(); + count += MlasLongExecuteTests<MlasComputeSiluAvx512Test>::RegisterLongExecute(); + } + return count; +}); + +#else + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool) { + return size_t{0}; +}); + +#endif diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 6078660bf0d6e..cd210f7bc70ba 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -1407,7 +1407,9 @@ TEST(NchwcOptimizerTests, UpsampleLinear) { } TEST(NchwcOptimizerTests, Activation) { - auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { + auto test_case = [&](const std::string& activation_op_type, + const std::string& domain = kOnnxDomain, + int opset_version = 13) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput<float>({1, 48, 11, 15}); auto* conv1_output_arg = helper.MakeIntermediate(); @@ -1431,23 +1433,93 @@ EXPECT_EQ(op_to_count["Add"], 1); }; - NchwcOptimizerTester(build_test_case, check_nchwc_graph); + NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version); }; // Verify that the optimizer doesn't add reorders for these activations in - // this pattern. Relu/Sigmoid/Tanh are generally fusable with a 
Relu/Sigmoid/Tanh/HardSigmoid are generally fusable with a // preceding convolution, but not here because the Conv output is consumed // both by the activation node and directly by the Add node. Gelu/QuickGelu // are also expected to remain as separate nodes. test_case("Relu"); test_case("Sigmoid"); test_case("Tanh"); + test_case("HardSigmoid"); + test_case("Gelu", kOnnxDomain, 20); test_case("Gelu", kMSDomain); test_case("QuickGelu", kMSDomain); } -TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { - auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { +TEST(NchwcOptimizerTests, ActivationSingleConsumerConvFusion) { + constexpr float kHardSigmoidAlpha = 0.125f; + constexpr float kHardSigmoidBeta = 0.625f; + + auto test_case = [&](const std::string& activation_op_type) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 48, 11, 15}); + auto* conv1_output_arg = helper.MakeIntermediate(); + auto* activation_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); + auto& activation_node = helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}); + if (activation_op_type == "HardSigmoid") { + activation_node.AddAttribute("alpha", kHardSigmoidAlpha); + activation_node.AddAttribute("beta", kHardSigmoidBeta); + } + helper.AddConvNode(activation_output_arg, output_arg, {16, 32, 1, 1}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto& graph = session.GetGraph(); + auto op_to_count = CountOpsInGraph(graph); + + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count[activation_op_type], 0); + + size_t fused_conv_count = 0; + for (const auto& node : graph.Nodes()) { + 
if (node.OpType() != "Conv" || node.Domain() != kMSNchwcDomain) { + continue; + } + + const auto& attributes = node.GetAttributes(); + auto activation_it = attributes.find("activation"); + if (activation_it == attributes.end()) { + continue; + } + + fused_conv_count++; + EXPECT_EQ(activation_it->second.s(), activation_op_type); + + auto activation_params_it = attributes.find("activation_params"); + if (activation_op_type == "HardSigmoid") { + ASSERT_NE(activation_params_it, attributes.end()); + ASSERT_EQ(activation_params_it->second.floats_size(), 2); + EXPECT_FLOAT_EQ(activation_params_it->second.floats(0), kHardSigmoidAlpha); + EXPECT_FLOAT_EQ(activation_params_it->second.floats(1), kHardSigmoidBeta); + } else { + EXPECT_EQ(activation_params_it, attributes.end()); + } + } + + EXPECT_EQ(fused_conv_count, 1U); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + for (const auto& activation_op_type : {"Relu", "Sigmoid", "Tanh", "HardSigmoid"}) { + test_case(activation_op_type); + } +} + +TEST(NchwcOptimizerTests, ActivationSingleConsumerConvNoFusion) { + auto test_case = [&](const std::string& activation_op_type, + const std::string& domain = kOnnxDomain, + int opset_version = 13) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 48, 11, 15}); auto* conv1_output_arg = helper.MakeIntermediate(); @@ -1477,12 +1549,13 @@ TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { } }; - NchwcOptimizerTester(build_test_case, check_nchwc_graph); + NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version); }; // Gelu/QuickGelu must remain separate even with a single-consumer Conv input, // because the NCHWc Conv activation fuse guard only allows a fixed subset of // activations. 
+ test_case("Gelu", kOnnxDomain, 20); test_case("Gelu", kMSDomain); test_case("QuickGelu", kMSDomain); } diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 5d7eda39be271..03005e3a07386 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -114,43 +114,15 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { // DQ contrib op schema is not updated to support blocked quantization + // Rejection doesn't depend on type/zp/accuracy_level — keep representative combos only. RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ_Cuda) { // DQ contrib op schema is not updated to support blocked quantization 
RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); } // Input2 @@ -179,7 +151,7 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 
1) / block_size; auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { @@ -224,42 +196,13 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, 
DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); } // Input1 @@ -295,7 +238,7 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { @@ -353,52 +296,27 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { - // DQ contrib op schema is not updated to support 
blocked quantization + // One representative type combo per rejection scenario (type doesn't affect rejection logic). // block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch_Cuda) { - // DQ contrib op schema is not updated to support blocked quantization + // One representative type combo per rejection scenario. 
// block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); } // Input1 @@ -727,7 +645,7 @@ RunDQMatMulFP16Converted(const std::vector& input1_shape, utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); 
utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); @@ -780,6 +698,602 @@ TEST(QDQTransformerTests, DQMatMulFP16ConvertedToMatMulNBits) { RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); } +// Per-tensor DQ -> MatMul conversion to MatMulNBits +// DQ has scalar scale (and optional scalar zero-point), no block_size attribute. +// Input1 +// | DQ(per-tensor) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulPerTensorConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + // Scalar scale (per-tensor) + auto* scale_arg = builder.MakeInitializer({}, {10.0f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, std::vector{T(1, 0)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function 
add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 0.01 /*per_sample_tolerance - higher due to blockwise accumulation reordering*/, + 5e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorConvertedToMatMulNBits) { + // Per-tensor: cover both types and a non-divisible K case. + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); +} + +// Per-channel (axis=1) DQ -> MatMul conversion to MatMulNBits +// DQ has 1D scale shape [N], axis=1, no block_size attribute. +// Input1 +// | DQ(per-channel axis=1) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulPerChannelConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + int64_t N = weight_shape[1]; + // 1D scale shape [N] for per-channel (axis=1) + auto* scale_arg = builder.MakeInitializer({N}, 8.0f, 12.0f); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(1)), attrs); + + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(std::vector{N}, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", 
{weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerChannelConvertedToMatMulNBits) { + RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); +} + +// Negative test: per-axis axis=0 with 1D scale should NOT fuse +template +void RunDQMatMulPerAxisAxis0NotConverted(const std::vector& input1_shape, + const std::vector& weight_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + int64_t K = weight_shape[0]; + // 1D scale shape [K] for per-axis axis=0 — should NOT match + auto* scale_arg = builder.MakeInitializer({K}, 8.0f, 12.0f); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), attrs); + + 
builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn = [](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "0"); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerAxisAxis0NotConvertedToMatMulNBits) { + RunDQMatMulPerAxisAxis0NotConverted({12, 32}, {32, 16}); +} + +// Per-tensor DQ -> MatMul with configurable block_size session option +template +void RunDQMatMulPerTensorWithBlockSize(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t block_size_option) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + auto* scale_arg = builder.MakeInitializer({}, {10.0f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, std::vector{T(1, 0)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, 
{output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + + // Verify the MatMulNBits node has the expected block_size attribute + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "MatMulNBits") { + auto& attrs = node.GetAttributes(); + auto bs_iter = attrs.find("block_size"); + ASSERT_NE(bs_iter, attrs.end()); + int64_t expected_bs = block_size_option > 0 ? block_size_option : 32; // default is 32 + EXPECT_EQ(bs_iter->second.i(), expected_bs); + } + } + }; + + std::function add_session_options_fn = + [block_size_option](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "0"); + std::ignore = sess_opts.config_options.AddConfigEntry( + kOrtSessionOptionsQDQMatMulNBitsBlockSize, + std::to_string(block_size_option).c_str()); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorWithBlockSizeOption) { + // Default block_size (0 -> 32) + RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 0); + // Explicit block_size=16 + RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 16); +} + +// UINT8 per-tensor DQ -> MatMul -> MatMulNBits +// Tests shapes from real models including small dimensions (N=1, N=8). 
+template +void RunDQMatMulPerTensorUint8Converted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, uint8_t(0), uint8_t(255)); + auto* dq_output = builder.MakeIntermediate(); + + // Scalar scale (per-tensor) + auto* scale_arg = builder.MakeInitializer({}, {0.05f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, {uint8_t(128)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 0.01 /*per_sample_tolerance - higher due to blockwise accumulation reordering*/, + 5e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorUint8ConvertedToMatMulNBits) { + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 
8}, 0); +} + +// --------------------------------------------------------------------------- +// DQ -> Gemm tests for MatMulNBits fusion +// --------------------------------------------------------------------------- + +// Input1 +// | DQ (4-bit weight) +// \ / +// Gemm +// | +// output +// Gemm has no bias, equivalent to MatMul. Should fuse to MatMulNBits. +template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedNoBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + 
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_NoBias) { + RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ (4-bit weight) bias (float) +// \ / / +// Gemm +// | +// output +// Gemm has a direct (non-DQ) float bias. Should fuse to MatMulNBits with bias at input 5. +template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedWithBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + int64_t N = weight_shape[1]; + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + 
builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + builder.AddNode("Gemm", {input_arg, dq_output, bias_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithBias) { + RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ (4-bit weight) DQ (bias) +// \ / / +// Gemm +// | +// output +// Gemm has a bias from DQ. Weight DQ fused into MatMulNBits, bias DQ stays alive, +// bias DQ output wired to MatMulNBits input 5. 
+template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedWithDQBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // Weight DQ + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + int64_t N = weight_shape[1]; + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + // Bias DQ (int8 quantized bias -> float) + auto* bias_quantized = builder.MakeInitializer({N}, std::vector(static_cast(N), 5)); + auto* bias_scale = builder.MakeInitializer({}, std::vector{0.1f}); + auto* bias_zp = builder.MakeInitializer({}, std::vector{0}); + auto* bias_dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {bias_quantized, bias_scale, bias_zp}, {bias_dq_output}); + + builder.AddNode("Gemm", {input_arg, dq_output, bias_dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + 
EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + // Weight DQ removed, bias DQ stays + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithDQBias) { + RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Negative test: DQ -> Gemm with transB=1 should NOT be fused. +TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_TransB) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({12, 37}, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // With transB=1, Gemm transposes B at runtime: weight shape [N,K]=[12,37], transposed to [K,N]=[37,12]. + // DQ weight shape is [12,37] (N=12, K=37 after transpose). 
+ NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", static_cast(16)), dq_attrs); + auto* weight_arg = builder.MakeInitializer({12, 37}, Int4x2(Int4x2::min_val, 0), Int4x2(Int4x2::max_val, 0)); + auto* scales_arg = builder.MakeInitializer({1, 37}, 8.0f, 12.0f); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + + NodeAttributes gemm_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("transB", static_cast(1)), gemm_attrs); + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}, "", &gemm_attrs); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5, 2e-5); +} + +// Negative test: DQ -> Gemm with alpha != 1.0 should NOT be fused. 
+TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_Alpha) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({12, 37}, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", static_cast(16)), dq_attrs); + auto* weight_arg = builder.MakeInitializer({37, 12}, Int4x2(Int4x2::min_val, 0), Int4x2(Int4x2::max_val, 0)); + auto* scales_arg = builder.MakeInitializer({3, 12}, 8.0f, 12.0f); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + + NodeAttributes gemm_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("alpha", 2.0f), gemm_attrs); + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}, "", &gemm_attrs); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5, 2e-5); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 2bf62b6944735..f732103842146 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -2171,5 +2171,105 @@ TEST_P(EinsumTransposeMatMulThreeInputsTest, EinsumTransposeMatMulThreeInputsTes INSTANTIATE_TEST_SUITE_P(EinsumTransposeMatMulThreeInputsTests, EinsumTransposeMatMulThreeInputsTest, testing::ValuesIn(case1)); +// Theme: High-rank contractions (WebGPU shader generation regression tests) + +// 5D contraction 
(Mamba-style chunked SSM state computation) +TEST(Einsum, ExplicitEinsumAs5DContraction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "bcknd,bckns->bcnds"); + test.AddInput("x", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddInput("y", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddOutput("o", {1, 1, 2, 2, 2}, + {26.f, 32.f, 32.f, 40.f, 58.f, 68.f, 68.f, 80.f}); + test.Run(); +} + +// 5D x 5D contraction (contract middle dims, keep outer + inner) +TEST(Einsum, ExplicitEinsumAs5DContraction_abcde_abcdf) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcde,abcdf->abef"); + test.AddInput("x", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddInput("y", {1, 1, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddOutput("o", {1, 1, 2, 2}, + {84.f, 100.f, 100.f, 120.f}); + test.Run(); +} + +// 5D x 5D contraction (contract 3 trailing dims) +TEST(Einsum, ExplicitEinsumAs5DContraction_abcde_afcde) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcde,afcde->abf"); + test.AddInput("x", {1, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); + test.AddInput("y", {1, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); + test.AddOutput("o", {1, 2, 2}, + {204.f, 492.f, 492.f, 1292.f}); + test.Run(); +} + +// 5D reduction (reduce 2 of 5 axes) +TEST(Einsum, ExplicitEinsumAs5DReduction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcde->ace"); + test.AddInput("x", {2, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}); + test.AddOutput("o", {2, 2, 2}, + 
{24.f, 28.f, 40.f, 44.f, 88.f, 92.f, 104.f, 108.f}); + test.Run(); +} + +// 6D x 6D contraction +TEST(Einsum, ExplicitEinsumAs6DContraction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcdef,abcdeg->abcfg"); + test.AddInput("x", {1, 1, 1, 1, 2, 2}, + {1.f, 2.f, 3.f, 4.f}); + test.AddInput("y", {1, 1, 1, 1, 2, 2}, + {1.f, 2.f, 3.f, 4.f}); + test.AddOutput("o", {1, 1, 1, 2, 2}, + {10.f, 14.f, 14.f, 20.f}); + test.Run(); +} + +// 6D reduction (reduce 3 of 6 axes) +TEST(Einsum, ExplicitEinsumAs6DReduction) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abcdef->adf"); + test.AddInput("x", {2, 2, 2, 2, 2, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, + 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f}); + test.AddOutput("o", {2, 2, 2}, + {112.f, 120.f, 144.f, 152.f, 368.f, 376.f, 400.f, 408.f}); + test.Run(); +} + +// 3-input bilinear form (x^T A y reduced to scalar) +TEST(Einsum, ExplicitEinsumAsBilinearFormToScalar) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "i,ij,j->"); + test.AddInput("x", {3}, {1.f, 2.f, 3.f}); + test.AddInput("y", {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + test.AddInput("z", {4}, {1.f, 2.f, 3.f, 4.f}); + test.AddOutput("o", {}, {500.f}); + test.Run(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py new file mode 100644 index 0000000000000..a6c4923cd961e --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -0,0 
+1,180 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Generate expected outputs for DeformConv tests using torchvision.ops.deform_conv2d. +Run with: .venv/bin/python onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +Outputs C++-friendly std::vector initializer lists for pasting into deform_conv_op_test.cc + +Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads +[pad_h, pad_w, pad_h, pad_w] are derived from a single (pad_h, pad_w) pair. Asymmetric +pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated. +""" + +import torch +import torchvision.ops + + +def _pair(x: int | tuple[int, int]) -> tuple[int, int]: + if isinstance(x, int): + return (x, x) + return x + + +def to_cpp_list(t: torch.Tensor, fmt="{:.6f}") -> str: + """Flatten tensor in NCHW order and format as C++ initializer list.""" + t = t.detach().float().contiguous() + return ", ".join(fmt.format(x) + "f" for x in t.flatten().tolist()) + + +def run_case( + name: str, + batch_sz: int, + n_in: int, + n_out: int, + n_weight_grps: int, + n_offset_grps: int, + kernel_h: int, + kernel_w: int, + stride: tuple[int, int] | int, + pad: tuple[int, int] | int, + dilation: tuple[int, int] | int, + in_h: int, + in_w: int, + seed: int = 42, +): + """Build inputs with seed, run deform_conv2d, print C++ snippets.""" + torch.manual_seed(seed) + stride_h, stride_w = _pair(stride) + pad_h, pad_w = _pair(pad) + dil_h, dil_w = _pair(dilation) + + out_h = (in_h + 2 * pad_h - (dil_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (kernel_w - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32) + offset = torch.randn(batch_sz, n_offset_grps * 2 * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32) + mask = torch.randn(batch_sz, n_offset_grps * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32) + weight = 
torch.randn(n_out, n_in // n_weight_grps, kernel_h, kernel_w, dtype=torch.float32) + bias = torch.randn(n_out, dtype=torch.float32) + + # Standard answer from torchvision + out = torchvision.ops.deform_conv2d( + x, + offset, + weight, + bias=bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dil_h, dil_w), + mask=mask, + ) + + # ONNX pads = [top, left, bottom, right] (symmetric: single pad_h, pad_w expanded) + pads_onnx = [pad_h, pad_w, pad_h, pad_w] + + print(f"// --- {name} (seed={seed}) ---") + print(f"// Shapes: X({batch_sz},{n_in},{in_h},{in_w}) W({n_out},{n_in // n_weight_grps},{kernel_h},{kernel_w})") + print(f"// stride=({stride_h},{stride_w}) pad=({pad_h},{pad_w}) dilation=({dil_h},{dil_w})") + print(f"// out_h={out_h} out_w={out_w}") + print() + print("std::vector X = {" + to_cpp_list(x) + "};") + print("std::vector W = {" + to_cpp_list(weight) + "};") + print("std::vector offset = {" + to_cpp_list(offset) + "};") + print("std::vector B = {" + to_cpp_list(bias) + "};") + print("std::vector mask = {" + to_cpp_list(mask) + "};") + print("std::vector expected_Y = {" + to_cpp_list(out) + "};") + print() + print( + "// Params: kernel_shape={" + f"{kernel_h}, {kernel_w}" + "}, stride={" + f"{stride_h}, {stride_w}" + "}, pads={" + + ", ".join(map(str, pads_onnx)) + + "}, dilations={" + + f"{dil_h}, {dil_w}" + + "}, group=" + + str(n_weight_grps) + + ", offset_group=" + + str(n_offset_grps) + ) + print() + return out + + +def main(): + print("// Generated by deform_conv_expected_gen.py (torchvision.ops.deform_conv2d)") + print() + + # Case 1: Same config as PyTorch TestDeformConv.get_fn_args (small batch for readability) + run_case( + "PyTorch get_fn_args style (batch=1)", + batch_sz=1, + n_in=6, + n_out=2, + n_weight_grps=2, + n_offset_grps=3, + kernel_h=3, + kernel_w=2, + stride=(2, 1), + pad=(1, 0), + dilation=(2, 1), + in_h=5, + in_w=4, + seed=42, + ) + + # Case 2: No mask (mask optional) - same config, then expected with mask=None + 
torch.manual_seed(42) + n_in, n_out = 6, 2 + n_weight_grps, n_offset_grps = 2, 3 + kH, kW = 3, 2 # noqa: N806 + stride_h, stride_w = 2, 1 + pad_h, pad_w = 1, 0 + dil_h, dil_w = 2, 1 + in_h, in_w = 5, 4 + batch_sz = 1 + out_h = (in_h + 2 * pad_h - (dil_h * (kH - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (kW - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32) + offset = torch.randn(batch_sz, n_offset_grps * 2 * kH * kW, out_h, out_w, dtype=torch.float32) + weight = torch.randn(n_out, n_in // n_weight_grps, kH, kW, dtype=torch.float32) + bias = torch.randn(n_out, dtype=torch.float32) + + out_no_mask = torchvision.ops.deform_conv2d( + x, + offset, + weight, + bias=bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dil_h, dil_w), + mask=None, + ) + print("// --- Same inputs, no mask (expected_Y when mask is omitted) ---") + print("std::vector expected_Y_no_mask = {" + to_cpp_list(out_no_mask) + "};") + print() + + # Case 3: groups=2, offset_group=2, non-zero offset (for GroupsWithNonZeroOffset test) + run_case( + "Groups with non-zero offset (batch=1, 2 groups)", + batch_sz=1, + n_in=4, + n_out=2, + n_weight_grps=2, + n_offset_grps=2, + kernel_h=2, + kernel_w=2, + stride=(1, 1), + pad=(0, 0), + dilation=(1, 1), + in_h=3, + in_w=3, + seed=123, + ) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc new file mode 100644 index 0000000000000..860c0d2f08b18 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -0,0 +1,948 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for DeformConv (CPU and Cuda), aligned with PyTorch Vision deform_conv2d tests. 
+// Reference: https://github.com/pytorch/vision/blob/main/test/test_ops.py (TestDeformConv) + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/testdata/deform_conv_test_data.inc" +#include "test/unittest_util/conversion.h" + +#if defined(USE_CUDA) +#include "test/common/cuda_op_test_utils.h" +#endif + +namespace onnxruntime { +namespace test { + +namespace { + +// Parameters similar to PyTorch TestDeformConv::get_fn_args (smaller for speed). +struct DeformConvTestParams { + int64_t batch_sz; + int64_t n_in_channels; + int64_t n_out_channels; + int64_t n_weight_grps; + int64_t n_offset_grps; + std::vector kernel_shape; // {kH, kW} + std::vector stride; + std::vector pad; + std::vector dilation; + int64_t in_h; + int64_t in_w; +}; + +// Traits for type-specific DeformConv test behavior. +template +struct DeformConvTestTraits; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return v; } + static std::unordered_set ExcludedProviders() { + return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-5f; } + static constexpr float DefaultAtol() { return 1e-5f; } +}; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return FloatsToMLFloat16s(v); } + static std::unordered_set ExcludedProviders() { + return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-2f; } + static constexpr float DefaultAtol() { return 1e-2f; } +}; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { + return std::vector(v.begin(), v.end()); + } + static std::unordered_set ExcludedProviders() { + return {kTensorrtExecutionProvider, 
kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; + } + static constexpr double DefaultRtol() { return 1e-8; } + static constexpr double DefaultAtol() { return 1e-8; } +}; + +#if defined(USE_CUDA) +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return FloatsToBFloat16s(v); } + static std::unordered_set ExcludedProviders() { + return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-2f; } + static constexpr float DefaultAtol() { return 1e-2f; } +}; +#endif + +template +void RunDeformConvTest(const DeformConvTestParams& params, + const std::vector& X, + const std::vector& W, + const std::vector& offset, + const std::vector& B, + const std::vector* mask, + const std::vector& expected_Y, + int opset = 19, + decltype(DeformConvTestTraits::DefaultRtol()) rtol = DeformConvTestTraits::DefaultRtol(), + decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol(), + bool omit_bias = false) { + const int64_t kH = params.kernel_shape[0]; + const int64_t kW = params.kernel_shape[1]; + // ONNX pads format: [pad_top, pad_left, pad_bottom, pad_right] = [pad[0], pad[1], pad[2], pad[3]] + const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - + params.dilation[0] * (kH - 1) - 1) / + params.stride[0] + + 1; + const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - + params.dilation[1] * (kW - 1) - 1) / + params.stride[1] + + 1; + + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + const std::vector X_shape = 
{params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; + const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; + const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; + const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + + auto X_t = DeformConvTestTraits::Convert(X); + auto W_t = DeformConvTestTraits::Convert(W); + auto offset_t = DeformConvTestTraits::Convert(offset); + auto expected_Y_t = DeformConvTestTraits::Convert(expected_Y); + + test.AddInput("X", X_shape, X_t); + test.AddInput("W", W_shape, W_t); + test.AddInput("offset", offset_shape, offset_t); + if (omit_bias) { + test.AddOptionalInputEdge(); + } else { + auto B_t = DeformConvTestTraits::Convert(B); + test.AddInput("B", {params.n_out_channels}, B_t); + } + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask)); + } else { + test.AddOptionalInputEdge(); + } + + const float rtol_f = static_cast(rtol); + const float atol_f = static_cast(atol); + test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol_f, atol_f); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); +} + +// MinimalBilinear test: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +// At (0,0) offset (0.5, 0.5) samples center of [1,2;3,4] -> 2.5. 
+template +void RunMinimalBilinearTest(int opset = 19, int min_cuda_arch = 0, bool omit_bias = false) { +#if defined(USE_CUDA) + if (min_cuda_arch > 0 && !HasCudaEnvironment(min_cuda_arch)) { + return; + } +#else + (void)min_cuda_arch; +#endif + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + // offset shape [N, 2*kH*kW, out_h, out_w] = [1, 2, 2, 2]: ch0=offset_h, ch1=offset_w (for kernel pt 0) + // Layout: offset[n,c,oh,ow]. Flattened (NCHW): [ch0@00, ch0@01, ch0@10, ch0@11, ch1@00, ch1@01, ch1@10, ch1@11] + // (0,0): (0.5, 0.5)->center of [1,2;3,4]->2.5; (0,1): (0,-1)->(0,0)->1; (1,0): (0,0)->3; (1,1): (0,0)->4 + std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, -1.0f, 0.f, 0.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + if (omit_bias) { + RunDeformConvTest(p, X, W, offset, {} /* B unused */, &mask, expected_Y, opset, + DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), true); + } else { + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, opset, + DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), false); + } +} +} // namespace + +// Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +TEST(DeformConvTest, MinimalBilinear) { + RunMinimalBilinearTest(); +} + +// Optional bias omitted: same as MinimalBilinear but B is not provided; output must match B=0. +TEST(DeformConvTest, OptionalBiasOmitted) { + RunMinimalBilinearTest(19, 0, true); +} + +// Minimal case FP16: Same as MinimalBilinear but in FP16 (CUDA-only). 
+#if defined(USE_CUDA) +TEST(DeformConvTest, MinimalBilinearFP16) { + RunMinimalBilinearTest(19, 530); +} + +// Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA-only, opset 22). +TEST(DeformConvTest, MinimalBilinearBFloat16) { + RunMinimalBilinearTest(22, 800); +} +#endif // defined(USE_CUDA) + +// Minimal case Double (FP64): Same as MinimalBilinear in double precision. +TEST(DeformConvTest, MinimalBilinearDouble) { + RunMinimalBilinearTest(); +} + +// Forward with mask and bias FP16 (CUDA-only; skip when CUDA not available). +#if defined(USE_CUDA) +TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { + int min_cuda_architecture = 530; // FP16 requires SM 5.3+ + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "DeformConv FP16: CUDA not available, skipping."; + return; + } + + DeformConvTestParams p = {}; + p.batch_sz = 2; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w); + const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2); + const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.5f, -0.5f}; + + const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w); + std::vector expected_Y(y_size); + for (int64_t b = 0; b < p.batch_sz; ++b) { + for (int64_t c = 0; c < p.n_out_channels; ++c) { + float val = (c % 2 == 0) ? 
0.58f : -0.42f; + for (int64_t i = 0; i < out_h * out_w; ++i) { + expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val; + } + } + } + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} +#endif // defined(USE_CUDA) + +// With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected. +TEST(DeformConvTest, ForwardWithMaskAndBias) { + DeformConvTestParams p = {}; + p.batch_sz = 2; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w); + const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2); + const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); // zero offset -> regular grid sampling + std::vector mask(mask_size, 1.f); + std::vector B = {0.5f, -0.5f}; + + // With offset=0, mask=1: deform_conv equals grouped conv. Per ONNX, group 0 -> output ch 0, group 1 -> ch 1. + // Uniform X=0.1, W=0.1, 2x2 kernel -> 0.08 + B per channel; Y[:,0,:,:]=0.58, Y[:,1,:,:]=-0.42. + const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w); + std::vector expected_Y(y_size); + for (int64_t b = 0; b < p.batch_sz; ++b) { + for (int64_t c = 0; c < p.n_out_channels; ++c) { + float val = (c % 2 == 0) ? 
0.58f : -0.42f; + for (int64_t i = 0; i < out_h * out_w; ++i) { + expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val; + } + } + } + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// No mask (optional): same as above but mask omitted; compare to run with ones mask via tolerance. +TEST(DeformConvTest, ForwardNoMask) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = 1 * 2 * 3 * 3; + const size_t w_size = 2 * 2 * 2 * 2; + const size_t offset_size = 1 * 2 * 2 * 2 * out_h * out_w; + const size_t y_size = 1 * 2 * out_h * out_w; + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector B(2, 0.f); + // No mask => mask=1. Zero offset => same as conv. Y = 4*2*0.1*0.1 = 0.08 per position. 
+ std::vector expected_Y(y_size, 0.08f); + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + const std::vector X_shape = {p.batch_sz, p.n_in_channels, p.in_h, p.in_w}; + const std::vector W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2}; + const std::vector offset_shape = {p.batch_sz, p.n_offset_grps * 2 * 2 * 2, out_h, out_w}; + const std::vector Y_shape = {p.batch_sz, p.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W); + test.AddInput("offset", offset_shape, offset); + test.AddInput("B", {p.n_out_channels}, B); + test.AddOptionalInputEdge(); // no mask + test.AddOutput("Y", Y_shape, expected_Y, false, 1e-4f, 1e-4f); + std::unordered_set excluded = {kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +// Empty batch (N=0): allowed, same as Conv/ConvTranspose/Pool — output shape [0, oC, oH, oW]. 
+TEST(DeformConvTest, EmptyBatch) { + DeformConvTestParams p = {}; + p.batch_sz = 0; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X; + std::vector W = std::vector(2 * 2 * 2 * 2, 0.1f); + std::vector offset; + std::vector B(2, 0.f); + std::vector expected_Y; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + const std::vector X_shape = {0, p.n_in_channels, p.in_h, p.in_w}; + const std::vector W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2}; + const std::vector offset_shape = {0, p.n_offset_grps * 2 * 2 * 2, out_h, out_w}; + const std::vector Y_shape = {0, p.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W); + test.AddInput("offset", offset_shape, offset); + test.AddInput("B", {p.n_out_channels}, B); + test.AddOptionalInputEdge(); + test.AddOutput("Y", Y_shape, expected_Y); + std::unordered_set excluded = {kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +// Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes). 
+TEST(DeformConvTest, WrongOffsetShape) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 2 * 3 * 3, 0.1f); + std::vector W(2 * 2 * 2 * 2, 0.1f); + std::vector wrong_offset(1 * 2 * out_h * out_w); // wrong: only 2 channels instead of 8 + std::vector B(2, 0.f); + std::vector expected_Y(1 * 2 * out_h * out_w, 0.f); + + const std::vector offset_shape_wrong = {1, 2, out_h, out_w}; + const std::vector Y_shape_wrong = {1, 2, out_h, out_w}; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + test.AddInput("X", {1, 2, 3, 3}, X); + test.AddInput("W", {2, 2, 2, 2}, W); + test.AddInput("offset", offset_shape_wrong, wrong_offset); // invalid channels + test.AddInput("B", {2}, B); + test.AddOptionalInputEdge(); + test.AddOutput("Y", Y_shape_wrong, expected_Y); + std::unordered_set excluded = {kTensorrtExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectFailure, "Offset channel count must be offset_group * 2 * kH * kW", excluded); +} + +// Wrong mask channel count -> expect failure. 
+TEST(DeformConvTest, WrongMaskShape) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 2 * 3 * 3, 0.1f); + std::vector W(2 * 2 * 2 * 2, 0.1f); + const size_t offset_size = static_cast( + p.batch_sz * p.n_offset_grps * 2 * p.kernel_shape[0] * p.kernel_shape[1] * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector B(2, 0.f); + std::vector wrong_mask(1 * 2 * out_h * out_w); // wrong: 2 instead of 4 + std::vector expected_Y(1 * 2 * out_h * out_w, 0.f); + + const std::vector mask_shape_wrong = {1, 2, out_h, out_w}; + const std::vector Y_shape_mask = {1, 2, out_h, out_w}; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + test.AddInput("X", {1, 2, 3, 3}, X); + test.AddInput("W", {2, 2, 2, 2}, W); + test.AddInput("offset", {1, 8, out_h, out_w}, offset); + test.AddInput("B", {2}, B); + test.AddInput("mask", mask_shape_wrong, wrong_mask); + test.AddOutput("Y", Y_shape_mask, expected_Y); + std::unordered_set excluded = {kTensorrtExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectFailure, "Mask channel count", excluded); +} + +// Opset 22 (same behavior, different opset). 
+TEST(DeformConvTest, Opset22) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, 0.f, 0.f, 0.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 2.f, 3.f, 4.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); +} + +// Non-square kernel (kH != kW): 2x3 kernel, zero offset -> same as standard conv. +TEST(DeformConvTest, NonSquareKernel) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 3}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 4; + p.in_w = 5; + // ONNX output size: out_h = (4 - 1*(2-1) - 1)/1 + 1 = 3, out_w = (5 - 1*(3-1) - 1)/1 + 1 = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t x_size = static_cast(1 * 1 * 4 * 5); + const size_t w_size = static_cast(1 * 1 * 2 * 3); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 3 * out_h * out_w); // n_offset_grps * 2 * kH * kW * out_h * out_w + const size_t mask_size = static_cast(1 * 1 * 2 * 3 * out_h * out_w); // n_offset_grps * kH * kW * out_h * out_w + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // With offset=0, mask=1: each output = 6 * 0.1 * 0.1 = 0.06 (9 positions) + std::vector expected_Y(static_cast(out_h * out_w), 0.06f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Asymmetric stride (stride_h != stride_w): stride=(2,1), zero offset. 
+TEST(DeformConvTest, AsymmetricStride) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {2, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 5; + p.in_w = 4; + // out_h = (5 - 1*(2-1) - 1) / 2 + 1 = 2, out_w = (4 - 1*(2-1) - 1) / 1 + 1 = 3 + const int64_t out_h = 2; + const int64_t out_w = 3; + + const size_t x_size = static_cast(1 * 1 * 5 * 4); + const size_t w_size = static_cast(1 * 1 * 2 * 2); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// groups > 0 and non-zero offset; expected from deform_conv_expected_gen.py (seed=123). 
+TEST(DeformConvTest, GroupsWithNonZeroOffset) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + + std::vector X = {0.296112f, 0.516562f, 0.251671f, 0.688557f, 0.073972f, 0.866522f, 0.136580f, 0.102479f, 0.184056f, 0.726447f, 0.315254f, 0.687107f, 0.075635f, 0.196638f, 0.316412f, 0.401740f, 0.118568f, 0.827395f, 0.382084f, 0.660494f, 0.853572f, 0.593153f, 0.636725f, 0.982629f, 0.274495f, 0.658376f, 0.277542f, 0.857325f, 0.899328f, 0.039014f, 0.926823f, 0.738757f, 0.717884f, 0.705837f, 0.915650f, 0.433980f}; + std::vector W = {-1.182045f, -0.287745f, -0.604301f, 0.600237f, -1.420473f, -0.223828f, 0.430555f, -0.898857f, -0.017858f, 0.426403f, -0.765741f, -0.054514f, -0.732053f, 1.234742f, 1.186221f, -0.220099f}; + std::vector offset = {-0.388483f, -0.934346f, -0.499144f, -1.086653f, 0.962421f, 0.249208f, -0.484502f, -2.092915f, 0.098284f, -0.093507f, 0.266215f, -0.585035f, -0.343038f, -0.682148f, -0.988689f, -1.701830f, -1.220290f, 1.313853f, 1.053300f, 0.138805f, -0.204445f, -2.268529f, -0.913328f, -0.420363f, -0.659559f, -0.797928f, 0.183831f, 0.229347f, 0.617743f, -0.287578f, 0.821824f, 0.151178f, -0.044382f, 1.623557f, -2.322871f, 1.087831f, -0.063545f, -0.448641f, -1.278470f, -1.144004f, -0.152640f, 0.116741f, 0.440260f, -1.446546f, -0.558082f, -0.051696f, -0.908273f, 0.350683f, -0.394809f, 0.489227f, -0.216815f, -1.747165f, 1.722842f, 0.773806f, 0.404630f, -1.646126f, -0.595084f, -0.711218f, 0.622965f, -1.372881f, -0.128065f, -1.283835f, -0.290120f, 1.276741f}; + std::vector B = {0.983955f, 0.204512f}; + std::vector mask = {-0.031861f, -0.478956f, 0.766809f, 0.027468f, 0.047470f, -0.923866f, -1.060737f, -2.324446f, -2.062818f, 0.006375f, -0.989555f, 0.701609f, -0.982238f, 0.277031f, 0.645495f, -0.895681f, 0.492753f, -0.014078f, -0.274663f, 
-0.764091f, -0.587157f, 1.195165f, -1.209575f, -0.556008f, -0.077105f, 1.277377f, -1.459629f, -2.159528f, -0.706709f, -0.922245f, 3.895372f, -0.602697f}; + std::vector expected_Y = {0.971546f, 1.139858f, 0.452817f, 1.863882f, -0.565266f, 1.423187f, -2.462833f, -0.104923f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// Sampling out of bounds: offset pushes sampling to (-5,-5), BilinearInterpolate returns 0. +TEST(DeformConvTest, OutOfBoundsSampling) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + // out_h=out_w=2 (2x2 output), offset shape [1, 2, 2, 2] = 8 values. All (-5,-5) -> OOB -> 0 + std::vector offset = {-5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {0.f, 0.f, 0.f, 0.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Dilation > 1: 2x2 kernel with dilation (2,2), zero offset -> 4 sample points with stride 2. 
+TEST(DeformConvTest, DilationGt1) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {2, 2}; + p.in_h = 5; + p.in_w = 5; + // out_h = (5 - 2*(2-1) - 1)/1 + 1 = 3, out_w = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t x_size = 25; + const size_t w_size = 4; + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Each output: 4 samples at (0,0),(0,2),(2,0),(2,2) -> 4 * 0.1 * 0.1 = 0.04 + std::vector expected_Y(static_cast(out_h * out_w), 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Decoupled groups: group=2, offset_group=1 (one offset map shared by all input channels). +TEST(DeformConvTest, DecoupledGroups) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(1 * 4 * 3 * 3); + const size_t w_size = static_cast(2 * 2 * 2 * 2); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f, 0.f}; + // Zero offset -> grouped conv. 
Per output ch: 2 in_ch * 4 kernel * 0.01 = 0.08 + std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.08f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Asymmetric padding: pads [top=1, left=0, bottom=0, right=1]; output 3x3, some positions have OOB samples. +TEST(DeformConvTest, AsymmetricPadding) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {1, 0, 0, 1}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + // out_h = (3+1+0-1*(2-1)-1)/1+1 = 3, out_w = (3+0+1-1-1)/1+1 = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Row 0: (0,0),(0,1) 2 valid -> 0.02; (0,2) only (0,2) in, (0,3) OOB -> 1 valid -> 0.01. Row 1/2: as before. + std::vector expected_Y = {0.02f, 0.02f, 0.01f, 0.04f, 0.04f, 0.02f, 0.04f, 0.04f, 0.02f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Tiny offset (near zero): offset (1e-6, 1e-6), sample ~(0,0) -> bilinear ≈ X[0,0]. Use 1x1 input for 1 output. 
+TEST(DeformConvTest, TinyOffset) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 1; + p.in_w = 1; + + std::vector X = {1.f}; + std::vector W = {1.f}; + std::vector offset = {1e-6f, 1e-6f}; + std::vector B = {0.f}; + std::vector mask = {1.f}; + std::vector expected_Y = {1.f}; // bilinear at (1e-6, 1e-6) ≈ 1 + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// Offset (0.5, 0.5) at each kernel point: sampling at (i+0.5, j+0.5) -> (0.5,0.5),(0.5,1.5),(1.5,0.5),(1.5,1.5). +// Only (0.5,0.5) is fully in-bounds for 2x2 input; others hit boundary (OOB gives 0). Result = 1.6875. +TEST(DeformConvTest, OffsetAtPixelCenters) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {0.25f, 0.25f, 0.25f, 0.25f}; + std::vector offset = { + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {1.6875f}; // op output: one center sample 2.5 + boundary samples + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Large batch (N=64) to trigger CUDA ComputeInternal chunking loop (b += n_parallel_imgs). 
+TEST(DeformConvTest, LargeBatchSize) { + const int64_t N = 64; + DeformConvTestParams p = {}; + p.batch_sz = N; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(N * 1 * 3 * 3); + const size_t offset_size = static_cast(N * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(N * 1 * 2 * 2 * out_h * out_w); + const size_t y_size = static_cast(N * 1 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(y_size, 0.04f); // 4 * 0.1 * 0.1 per position + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// group=1, offset_group=2: weights not grouped, offset/mask grouped. +TEST(DeformConvTest, Group1OffsetGroup2) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; // C must be divisible by offset_group + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(1 * 4 * 3 * 3); + const size_t w_size = static_cast(2 * 4 * 2 * 2); + const size_t offset_size = static_cast(1 * 2 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 2 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f, 0.f}; + // group=1: full conv. 
Each output: 4 in_ch * 4 kernel = 16 * 0.01 = 0.16 per channel, 2 out ch -> 0.16 each + std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.16f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Mask with zeros: exercises CUDA early-exit when mask_val == 0. +TEST(DeformConvTest, MaskWithZeros) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + std::vector offset(offset_size, 0.f); + // mask: (1, 4, 2, 2). Set all to 0 -> output should be 0. + std::vector mask(static_cast(1 * 1 * 2 * 2 * out_h * out_w), 0.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Extreme aspect ratio (1x100): thin horizontal strip to verify coordinate indexing. 
+TEST(DeformConvTest, ExtremeAspectRatio) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 1;
+  p.n_in_channels = 1;
+  p.n_out_channels = 1;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {1, 3};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 1;
+  p.in_w = 100;
+  // out_h = 1, out_w = (100 - 1*(3-1) - 1)/1 + 1 = 98
+  const int64_t out_h = 1;
+  const int64_t out_w = 98;
+
+  std::vector<float> X(100, 0.1f);
+  std::vector<float> W(1 * 1 * 1 * 3, 0.1f);
+  const size_t offset_size = static_cast<size_t>(1 * 1 * 2 * 1 * 3 * out_h * out_w);
+  const size_t mask_size = static_cast<size_t>(1 * 1 * 1 * 3 * out_h * out_w);
+  std::vector<float> offset(offset_size, 0.f);
+  std::vector<float> mask(mask_size, 1.f);
+  std::vector<float> B = {0.f};
+  // Each output: 3 * 0.1 * 0.1 = 0.03
+  std::vector<float> expected_Y(static_cast<size_t>(out_h * out_w), 0.03f);
+
+  RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y);
+}
+
+// ONNX model data test: deform_conv_test_gen.py builds the ONNX model (via onnx.helper)
+// and generates fixed inputs from torchvision (seed=123). This test is a model-loading/
+// integration smoke test that uses ORT-generated outputs from deform_conv_test.onnx as the reference.
+TEST(DeformConvTest, OnnxModelTest) {
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", std::vector<int64_t>{2, 2});
+  test.AddAttribute("strides", std::vector<int64_t>{1, 1});
+  test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
+  test.AddAttribute("dilations", std::vector<int64_t>{1, 1});
+  test.AddAttribute("group", static_cast<int64_t>(2));
+  test.AddAttribute("offset_group", static_cast<int64_t>(2));
+
+  test.AddInput<float>("X", {1, 4, 3, 3}, kDeformConvOnnxTest_X);
+  test.AddInput<float>("W", {2, 2, 2, 2}, kDeformConvOnnxTest_W);
+  test.AddInput<float>("offset", {1, 16, 2, 2}, kDeformConvOnnxTest_offset);
+  test.AddInput<float>("B", {2}, kDeformConvOnnxTest_B);
+  test.AddInput<float>("mask", {1, 8, 2, 2}, kDeformConvOnnxTest_mask);
+  test.AddReferenceOutputs("testdata/deform_conv_test.onnx", 1e-4f);
+
+  std::unordered_set<std::string> excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider,
+                                              kQnnExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
index 1eeb3683bc9aa..b2abe353693a2 100644
--- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
+++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "test/common/trt_op_test_utils.h" @@ -906,5 +907,138 @@ TEST(RoiAlignTest, BatchIndicesNegative_CUDA) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); #endif } + +TEST(RoiAlignTest, Float16_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, Float16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + 
test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, BFloat16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToBFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToBFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToBFloat16({1.25f})); + + test.SetOutputTolerance(0.05f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test half_pixel mode (default for Opset 16+) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, sampling_ratio=2. 
+// Expected values from ONNX reference implementation: {11.25, 14.75, 39.25, 42.75} +TEST(RoiAlignTest, Float16_HalfPixel_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + std::vector X_val(64); + for (int i = 0; i < 64; ++i) X_val[i] = static_cast(i); + test.AddInput("X", {1, 1, 8, 8}, ToFloat16(X_val)); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 7., 7.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 2, 2}, ToFloat16({11.25f, 14.75f, 39.25f, 42.75f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test adaptive sampling (sampling_ratio=0) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, half_pixel mode. +// Adaptive: ceil(3.0/2)=2 samples per dim. 
+// Expected values from ONNX reference implementation: {11.39062, 14.875, 39.26562, 42.75} +TEST(RoiAlignTest, Float16_AdaptiveSampling_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 0); // adaptive + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + std::vector X_val(64); + for (int i = 0; i < 64; ++i) X_val[i] = static_cast(i); + test.AddInput("X", {1, 1, 8, 8}, ToFloat16(X_val)); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 7., 7.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 2, 2}, + ToFloat16({11.39062f, 14.875f, 39.26562f, 42.75f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc b/onnxruntime/test/providers/cpu/tensor/pad_test.cc index 9169f2e6b5ca9..990e4354c3626 100644 --- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc @@ -124,6 +124,37 @@ static void RunAllOpsetAllDomainPadTests( } } +#ifdef USE_CUDA +template +static void RunCudaOnlyOnnxOpsetPadTest( + int opset, + const std::vector& input_dims, + const std::vector& input, + const std::vector& pads, + T value, + const std::vector& output_dims, + const std::vector& output, + const std::string& mode = "constant") { + auto cuda_execution_provider = DefaultCudaExecutionProvider(); + if (cuda_execution_provider == nullptr) { + GTEST_SKIP() << "CUDA execution provider is not available"; 
+ } + + OpTester test("Pad", opset); + if (mode != "constant") { + test.AddAttribute("mode", mode); + } + test.AddInput("data", input_dims, input); + test.AddInput("pads", {static_cast(pads.size())}, pads, true); + test.AddInput("value", {}, {value}, true); + test.AddOutput("output", output_dims, output); + + std::vector> execution_providers; + execution_providers.emplace_back(std::move(cuda_execution_provider)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + // Some of the tests can't run on TensorrtExecutionProvider because only constant mode and value 0 of "Pad" node is supported. // Those tests will fallback to other EP. @@ -199,6 +230,48 @@ TYPED_TEST(PadOpTest, Pad_Edge_1D) { "edge"); } +#ifdef USE_CUDA +TEST(PadOpTest, Pad_Edge_CudaOnly_MLFloat16_SupportedOpsets) { + const std::vector supported_opsets{18, 19, 20, 21, 22, 23, 24, 25}; + for (int opset : supported_opsets) { + SCOPED_TRACE(MakeString("opset: ", opset)); + RunCudaOnlyOnnxOpsetPadTest( + opset, + {3, 2}, + {MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f)}, + {0, 2, 0, 1}, + MLFloat16(0.0f), + {3, 5}, + {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(6.0f)}, + "edge"); + } +} + +TEST(PadOpTest, Pad_Wrap_CudaOnly_Float_SupportedOpsets) { + const std::vector supported_opsets{19, 20, 21, 22, 23, 24, 25}; + for (int opset : supported_opsets) { + SCOPED_TRACE(MakeString("opset: ", opset)); + RunCudaOnlyOnnxOpsetPadTest( + opset, + {3, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + 5.0f, 6.0f}, + {0, 1, 0, 1}, + 0.0f, + {3, 4}, + {2.0f, 1.0f, 2.0f, 1.0f, + 4.0f, 3.0f, 4.0f, 3.0f, + 6.0f, 5.0f, 6.0f, 5.0f}, + "wrap"); + } +} +#endif + TYPED_TEST(PadOpTest, Pad_Constant_2D) { using T = TypeParam; 
RunAllOpsetAllDomainPadTests({2, 2}, @@ -1391,9 +1464,7 @@ TEST(PadOpTest, Pad_Wrap_NegativeFront_PositiveBack) { // Post-slice core: [4]; wrap 3 -> [4, 4, 4, 4] const std::vector expected_data = {4, 4, 4, 4}; - // CUDA registers only up to 18 and does not impl wrap mode - // so we force version to 19 to automatically exclude EPs that do not - // implement wrap mode similar to the above tests. + // Use opset 19 to exercise wrap mode, which is supported from Pad-19 onward. OpTester test("Pad", 19); test.AddInput("data", input_shape, input_data); test.AddInput("pads", {static_cast(pads.size())}, pads, true); diff --git a/onnxruntime/test/testdata/deform_conv_test.onnx b/onnxruntime/test/testdata/deform_conv_test.onnx new file mode 100644 index 0000000000000..b643014e44acb Binary files /dev/null and b/onnxruntime/test/testdata/deform_conv_test.onnx differ diff --git a/onnxruntime/test/testdata/deform_conv_test_data.inc b/onnxruntime/test/testdata/deform_conv_test_data.inc new file mode 100644 index 0000000000000..206d8517dd3e3 --- /dev/null +++ b/onnxruntime/test/testdata/deform_conv_test_data.inc @@ -0,0 +1,10 @@ +// Auto-generated by deform_conv_test_gen.py - do not edit + +#include + +static const std::vector kDeformConvOnnxTest_X = {0.296111941f, 0.516562283f, 0.251670718f, 0.68855679f, 0.0739724636f, 0.866521955f, 0.136579871f, 0.102479041f, 0.184056461f, 0.726446748f, 0.315253913f, 0.687106669f, 0.075635314f, 0.196638167f, 0.316411972f, 0.401740134f, 0.118568301f, 0.82739538f, 0.382084429f, 0.660493851f, 0.853571773f, 0.593153f, 0.636725366f, 0.982629359f, 0.274495304f, 0.658375621f, 0.277541935f, 0.857324839f, 0.899328232f, 0.0390138626f, 0.926822901f, 0.738757193f, 0.717883527f, 0.705837429f, 0.915649533f, 0.433980227f}; +static const std::vector kDeformConvOnnxTest_W = {-1.18204546f, -0.287744999f, -0.604300678f, 0.600236714f, -1.42047262f, -0.223827749f, 0.430554837f, -0.89885664f, -0.0178579595f, 0.426403075f, -0.765740693f, -0.0545141846f, -0.732052684f, 
1.23474216f, 1.18622088f, -0.220098898f}; +static const std::vector kDeformConvOnnxTest_offset = {-0.388483077f, -0.934345901f, -0.499144107f, -1.08665264f, 0.962421f, 0.249208495f, -0.484502077f, -2.09291434f, 0.0982837752f, -0.0935074314f, 0.266214728f, -0.585035503f, -0.343037993f, -0.682147384f, -0.988689423f, -1.70183039f, -1.2202903f, 1.31385386f, 1.05329967f, 0.138805181f, -0.204444751f, -2.26852894f, -0.913327932f, -0.420362711f, -0.659559608f, -0.797927678f, 0.18383126f, 0.229347408f, 0.617742658f, -0.287577927f, 0.821824312f, 0.151177585f, -0.0443819836f, 1.62355745f, -2.32287097f, 1.08783054f, -0.0635453761f, -0.448640704f, -1.27846932f, -1.14400387f, -0.152640373f, 0.116741188f, 0.44026047f, -1.44654655f, -0.558081627f, -0.0516963229f, -0.90827328f, 0.350683212f, -0.394808769f, 0.489227712f, -0.216814891f, -1.74716449f, 1.72284174f, 0.773806036f, 0.404629797f, -1.64612663f, -0.59508425f, -0.711217523f, 0.622964859f, -1.37288189f, -0.128064156f, -1.28383458f, -0.290120065f, 1.27674019f}; +static const std::vector kDeformConvOnnxTest_B = {0.983955026f, 0.204511523f}; +static const std::vector kDeformConvOnnxTest_mask = {-0.0318612382f, -0.478955716f, 0.766808629f, 0.0274681915f, 0.0474699028f, -0.92386651f, -1.06073678f, -2.32444572f, -2.06281757f, 0.00637452863f, -0.989554703f, 0.701609194f, -0.982237995f, 0.277030349f, 0.645495057f, -0.895680785f, 0.492752999f, -0.0140781598f, -0.274662733f, -0.764091492f, -0.58715719f, 1.1951654f, -1.20957518f, -0.556007624f, -0.0771045536f, 1.27737665f, -1.45962942f, -2.15952778f, -0.70670861f, -0.92224431f, 3.89537215f, -0.602696717f}; +static const std::vector kDeformConvOnnxTest_expected_Y = {0.971546292f, 1.1398586f, 0.452816963f, 1.86388242f, -0.565265715f, 1.42318761f, -2.46283293f, -0.104923099f}; diff --git a/onnxruntime/test/testdata/deform_conv_test_data.npz b/onnxruntime/test/testdata/deform_conv_test_data.npz new file mode 100644 index 0000000000000..68639f753501d Binary files /dev/null and 
b/onnxruntime/test/testdata/deform_conv_test_data.npz differ diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py new file mode 100644 index 0000000000000..120fb1ed4c211 --- /dev/null +++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Generate DeformConv ONNX model and test data for cross-platform validation. + +Based on ONNX DeformConv spec (opset 19+): https://onnx.ai/onnx/operators/onnx__DeformConv.html +Uses a moderately complex config: groups=2, offset_group=2, 2x2 kernel, non-zero offsets. +Reference output from torchvision.ops.deform_conv2d. + +Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads +[pad_top, pad_left, pad_bottom, pad_right] = [pad_h, pad_w, pad_h, pad_w]. Asymmetric +pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated. 
+ +Run from repo root: + python onnxruntime/test/testdata/nn/deform_conv_test_gen.py + +Outputs: + - deform_conv_test.onnx + - deform_conv_test_data.npz (X, W, offset, B, mask, expected_Y) + - deform_conv_test_data.inc (C++ arrays for op test) +""" + +from pathlib import Path + +import numpy as np +from onnx import TensorProto, checker, helper, save + +try: + import onnxruntime as ort +except ImportError: + ort = None +import torch +import torchvision.ops + +# Config: groups=2, offset_group=2, 2x2 kernel (from deform_conv_expected_gen Case 3) +BATCH = 1 +N_IN = 4 +N_OUT = 2 +N_WEIGHT_GRPS = 2 +N_OFFSET_GRPS = 2 +KH, KW = 2, 2 +STRIDE_H, STRIDE_W = 1, 1 +PAD_H, PAD_W = 0, 0 +DIL_H, DIL_W = 1, 1 +IN_H, IN_W = 3, 3 +SEED = 123 + +OUT_H = (IN_H + 2 * PAD_H - (DIL_H * (KH - 1) + 1)) // STRIDE_H + 1 +OUT_W = (IN_W + 2 * PAD_W - (DIL_W * (KW - 1) + 1)) // STRIDE_W + 1 + + +def _generate_reference(): + """Generate inputs and expected output via torchvision.ops.deform_conv2d.""" + torch.manual_seed(SEED) + x = torch.rand(BATCH, N_IN, IN_H, IN_W, dtype=torch.float32) + offset = torch.randn(BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W, dtype=torch.float32) + mask = torch.randn(BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W, dtype=torch.float32) + weight = torch.randn(N_OUT, N_IN // N_WEIGHT_GRPS, KH, KW, dtype=torch.float32) + bias = torch.randn(N_OUT, dtype=torch.float32) + + out = torchvision.ops.deform_conv2d( + x, + offset, + weight, + bias=bias, + stride=(STRIDE_H, STRIDE_W), + padding=(PAD_H, PAD_W), + dilation=(DIL_H, DIL_W), + mask=mask, + ) + + return { + "X": x.numpy(), + "W": weight.numpy(), + "offset": offset.numpy(), + "B": bias.numpy(), + "mask": mask.numpy(), + "expected_Y": out.numpy(), + } + + +def _build_onnx_model(): + """Build DeformConv ONNX model. 
ONNX pads = [pad_top, pad_left, pad_bottom, pad_right].""" + # Symmetric padding only: (pad_h, pad_w) -> [pad_h, pad_w, pad_h, pad_w] + pads = [PAD_H, PAD_W, PAD_H, PAD_W] + + node = helper.make_node( + "DeformConv", + inputs=["X", "W", "offset", "B", "mask"], + outputs=["Y"], + kernel_shape=[KH, KW], + strides=[STRIDE_H, STRIDE_W], + pads=pads, + dilations=[DIL_H, DIL_W], + group=N_WEIGHT_GRPS, + offset_group=N_OFFSET_GRPS, + ) + + graph = helper.make_graph( + [node], + "DeformConvTest", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [BATCH, N_IN, IN_H, IN_W]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [N_OUT, N_IN // N_WEIGHT_GRPS, KH, KW]), + helper.make_tensor_value_info( + "offset", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W] + ), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [N_OUT]), + helper.make_tensor_value_info("mask", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [BATCH, N_OUT, OUT_H, OUT_W])], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)]) + checker.check_model(model) + return model + + +def _to_cpp_array(name: str, arr: np.ndarray) -> str: + """Format numpy array as C++ initializer list.""" + flat = arr.flatten().tolist() + vals = ", ".join(f"{x:.9g}f" for x in flat) + return f"static const std::vector {name} = {{{vals}}};" + + +def _write_cpp_inc(data: dict, inc_path: Path) -> None: + """Write C++ include file with test data.""" + lines = [ + "// Auto-generated by deform_conv_test_gen.py - do not edit", + "", + "#include ", + "", + _to_cpp_array("kDeformConvOnnxTest_X", data["X"]), + _to_cpp_array("kDeformConvOnnxTest_W", data["W"]), + _to_cpp_array("kDeformConvOnnxTest_offset", data["offset"]), + _to_cpp_array("kDeformConvOnnxTest_B", data["B"]), + _to_cpp_array("kDeformConvOnnxTest_mask", data["mask"]), + _to_cpp_array("kDeformConvOnnxTest_expected_Y", 
data["expected_Y"]), + "", + ] + inc_path.write_text("\n".join(lines), encoding="utf-8") + + +def main(): + # Output to testdata/ root (same as layernorm.onnx, attention_past_state.onnx, etc.) + script_dir = Path(__file__).resolve().parent + assert script_dir.name == "nn", "Script must live in testdata/nn/" + testdata_root = script_dir.parent + model_path = testdata_root / "deform_conv_test.onnx" + data_path = testdata_root / "deform_conv_test_data.npz" + inc_path = testdata_root / "deform_conv_test_data.inc" + + print("Generating reference via torchvision.ops.deform_conv2d...") + data = _generate_reference() + + print("Building ONNX model...") + model = _build_onnx_model() + save(model, str(model_path)) + print(f" Saved {model_path}") + + np.savez(str(data_path), **data) + print(f" Saved {data_path}") + + _write_cpp_inc(data, inc_path) + print(f" Saved {inc_path}") + + # Validate with onnxruntime if available + if ort is not None: + print("Validating with ONNX Runtime...") + sess = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + ort_out = sess.run( + ["Y"], + { + "X": data["X"], + "W": data["W"], + "offset": data["offset"], + "B": data["B"], + "mask": data["mask"], + }, + )[0] + + rtol, atol = 1e-4, 1e-4 + if np.allclose(ort_out, data["expected_Y"], rtol=rtol, atol=atol): + print(" PASS: ORT output matches reference.") + else: + diff = np.abs(ort_out.astype(np.float64) - data["expected_Y"].astype(np.float64)) + print(f" FAIL: max |diff|={diff.max()}, mean={diff.mean()}") + else: + print(" (onnxruntime not installed; skip validation)") + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 18ead92ce3f18..1f744df14cfb8 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -749,10 +749,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, 
execution_provider = DefaultXnnpackExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); -#if !defined(USE_WEBGPU) || !defined(ORT_USE_EP_API_ADAPTERS) else if (provider_type == onnxruntime::kWebGpuExecutionProvider) execution_provider = DefaultWebGpuExecutionProvider(); -#endif else if (provider_type == dynamic_plugin_ep_name) { execution_provider = dynamic_plugin_ep_infra::MakeEp(); } diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc index fd2cf2f712628..1f82e1f893eab 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -11,6 +11,7 @@ #include "nlohmann/json.hpp" #include "core/common/common.h" +#include "core/framework/config_options.h" #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" #include "core/session/ort_env.h" @@ -167,7 +168,7 @@ void Shutdown() { g_plugin_ep_infrastructure_state.reset(); } -std::unique_ptr MakeEp(const logging::Logger* logger) { +std::unique_ptr MakeEp(const logging::Logger* logger, const ConfigOptions* ep_options) { if (!IsInitialized()) { return nullptr; } @@ -182,6 +183,13 @@ std::unique_ptr MakeEp(const logging::Logger* logger) { StrMapToKeyValueCstrVectors(state.config.default_ep_options, default_ep_option_key_cstrs, default_ep_option_value_cstrs); + if (ep_options != nullptr) { + for (const auto& [key, value] : ep_options->configurations) { + default_ep_option_key_cstrs.push_back(key.c_str()); + default_ep_option_value_cstrs.push_back(value.c_str()); + } + } + OrtSessionOptions ort_session_options{}; ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, default_ep_option_key_cstrs, diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h index 
680045be9330c..0962df8e35308 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h @@ -17,6 +17,7 @@ namespace onnxruntime { struct IExecutionProviderFactory; class IExecutionProvider; +struct ConfigOptions; namespace logging { class Logger; @@ -74,7 +75,8 @@ bool IsInitialized(); void Shutdown(); // Returns a dynamic plugin EP `IExecutionProvider` instance, or `nullptr` if uninitialized. -std::unique_ptr MakeEp(const logging::Logger* logger = nullptr); +// `ep_options` provides additional EP-specific option overrides (key-value pairs) on top of the defaults. +std::unique_ptr MakeEp(const logging::Logger* logger = nullptr, const ConfigOptions* ep_options = nullptr); // Gets the dynamic plugin EP name, or `std::nullopt` if uninitialized. std::optional GetEpName(); diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 6dc38f84c79d5..7e6bc6ae06020 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -14,8 +14,13 @@ #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" #endif +#if defined(USE_WEBGPU) +#include "core/graph/constants.h" +#include "core/session/abi_session_options_impl.h" +#endif #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/providers.h" +#include "test/unittest_util/test_dynamic_plugin_ep.h" namespace onnxruntime { @@ -273,19 +278,37 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { } std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) { -#if defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) +#if defined(USE_WEBGPU) ConfigOptions config_options{}; + + // Helper to strip the EP prefix from config entry keys when building as a plugin EP. + // The full key is like "ep.webgpuexecutionprovider.storageBufferCacheMode", and the + // config entry expects just "storageBufferCacheMode" in the EP API build. 
+ // Returns a pointer into the original string, so the result is valid as long as the input is. + auto strip_ep_prefix = [](const char* full_key) -> const char* { +#if defined(ORT_USE_EP_API_ADAPTERS) + std::string_view key{full_key}; + std::string_view prefix = OrtSessionOptions::GetProviderOptionPrefix(kWebGpuExecutionProvider); + ORT_ENFORCE(key.length() >= prefix.length() && key.substr(0, prefix.length()) == prefix, + "Config key \"", key, "\" does not start with expected prefix \"", prefix, "\""); + return full_key + prefix.length(); +#else + return full_key; +#endif + }; + // Disable storage buffer cache - ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + ORT_ENFORCE(config_options.AddConfigEntry(strip_ep_prefix(webgpu::options::kStorageBufferCacheMode), webgpu::options::kBufferCacheMode_Disabled) .IsOK()); if (!is_nhwc) { // Enable NCHW support - ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreferredLayout, + ORT_ENFORCE(config_options.AddConfigEntry(strip_ep_prefix(webgpu::options::kPreferredLayout), webgpu::options::kPreferredLayout_NCHW) .IsOK()); } - return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); + + return WebGpuExecutionProviderWithOptions(config_options); #else ORT_UNUSED_PARAMETER(is_nhwc); return nullptr; @@ -293,8 +316,16 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) } std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) { -#if defined(USE_WEBGPU) && !defined(ORT_USE_EP_API_ADAPTERS) +#if defined(USE_WEBGPU) +#if defined(ORT_USE_EP_API_ADAPTERS) + auto ep_name = dynamic_plugin_ep_infra::GetEpName(); + ORT_ENFORCE(ep_name == kWebGpuExecutionProvider, + "Dynamic plugin EP is not the WebGPU EP. 
Expected \"", kWebGpuExecutionProvider, + "\", got \"", ep_name.value_or(""), "\""); + return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options); +#else return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); +#endif #else ORT_UNUSED_PARAMETER(config_options); return nullptr; diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index f767ef110561a..88d2981e2ccaa 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -194,6 +194,7 @@ stages: - template: ../templates/py-linux-qnn.yml parameters: machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' + QnnSdk: ${{ parameters.qnn_sdk_version }} extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true