Skip to content

Commit 6bf490a

Browse files
authored
Merge pull request #989 from intel/sync_msft_25032026
Sync with Microsoft ONNX Runtime - 25032026
2 parents 4846b1b + 6af93e8 commit 6bf490a

102 files changed

Lines changed: 7080 additions & 512 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.github/workflows/windows_webgpu.yml

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,86 @@ jobs:
155155
working-directory: ${{ github.workspace }}\csharp
156156
continue-on-error: true
157157

158+
webgpu_plugin_build_x64_RelWithDebInfo:
159+
runs-on: [
160+
"self-hosted",
161+
"1ES.Pool=onnxruntime-github-Win2022-GPU-A10",
162+
"JobId=webgpu_plugin_build_x64_RelWithDebInfo-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}"
163+
]
164+
timeout-minutes: 300
165+
env:
166+
OnnxRuntimeBuildDirectory: ${{ github.workspace }}
167+
setVcvars: true
168+
ALLOW_RELEASED_ONNX_OPSET_ONLY: "0"
169+
DocUpdateNeeded: false
170+
NVIDIA_TF32_OVERRIDE: "0"
171+
ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0"
172+
steps:
173+
- name: Checkout
174+
uses: actions/checkout@v6
175+
with:
176+
fetch-depth: 0
177+
submodules: none
178+
179+
- name: Setup Python 3.12
180+
uses: actions/setup-python@v6
181+
with:
182+
python-version: "3.12"
183+
architecture: x64
184+
185+
- name: Locate vcvarsall and Setup Env
186+
uses: ./.github/actions/locate-vcvarsall-and-setup-env
187+
with:
188+
architecture: x64
189+
190+
- name: Install python modules
191+
run: python -m pip install -r tools\ci_build\github\windows\python\requirements.txt
192+
shell: cmd
193+
working-directory: ${{ github.workspace }}
194+
195+
- name: Setup Node.js
196+
uses: actions/setup-node@v6
197+
with:
198+
node-version: "20.x"
199+
200+
- uses: actions/cache@v5
201+
id: onnx-node-tests-cache
202+
with:
203+
path: ${{ github.workspace }}/js/test/
204+
key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }}
205+
206+
- name: Build and Test
207+
shell: pwsh
208+
run: |
209+
python.exe ${{ github.workspace }}\tools\ci_build\build.py `
210+
--config RelWithDebInfo `
211+
--build_dir ${{ github.workspace }} `
212+
--skip_submodule_sync `
213+
--parallel `
214+
--use_binskim_compliant_compile_flags `
215+
--cmake_generator "Visual Studio 17 2022" `
216+
--enable_onnx_tests `
217+
--use_webgpu shared_lib `
218+
--wgsl_template static `
219+
--use_vcpkg --use_vcpkg_ms_internal_asset_cache `
220+
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_DAWN_BACKEND_D3D12=1 onnxruntime_ENABLE_DAWN_BACKEND_VULKAN=1 `
221+
--disable_rtti `
222+
--enable_lto
223+
224+
if ($lastExitCode -ne 0) {
225+
exit $lastExitCode
226+
}
227+
228+
- name: Publish artifacts
229+
uses: actions/upload-artifact@v4
230+
with:
231+
name: webgpu-plugin-binaries
232+
path: |
233+
${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_webgpu.dll
234+
${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_webgpu.pdb
235+
${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/dxcompiler.dll
236+
${{ github.workspace }}/RelWithDebInfo/RelWithDebInfo/dxil.dll
237+
158238
webgpu_external_dawn_build_x64_RelWithDebInfo:
159239
runs-on: [
160240
"self-hosted",

cmake/onnxruntime_mlas.cmake

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
3434
${MLAS_SRC_DIR}/eltwise.h
3535
${MLAS_SRC_DIR}/eltwise.cpp
3636
${MLAS_SRC_DIR}/erf.cpp
37+
${MLAS_SRC_DIR}/silu.cpp
38+
${MLAS_SRC_DIR}/gelu.cpp
3739
${MLAS_SRC_DIR}/compute.cpp
3840
${MLAS_SRC_DIR}/dequantize.cpp
3941
${MLAS_SRC_DIR}/quantize.cpp
@@ -201,6 +203,14 @@ function(setup_mlas_source_for_windows)
201203
)
202204
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2")
203205

206+
set(mlas_platform_srcs_avx512
207+
${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp
208+
${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp
209+
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
210+
)
211+
212+
set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512")
213+
204214
target_sources(onnxruntime_mlas PRIVATE
205215
${MLAS_SRC_DIR}/dgemm.cpp
206216
${mlas_platform_srcs_avx}
@@ -212,7 +222,7 @@ function(setup_mlas_source_for_windows)
212222
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
213223
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
214224
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
215-
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
225+
${mlas_platform_srcs_avx512}
216226
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
217227
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
218228
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
@@ -764,6 +774,8 @@ endif()
764774
${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S
765775
${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S
766776
${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S
777+
${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp
778+
${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp
767779
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
768780
)
769781
set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f")

cmake/onnxruntime_providers_webgpu.cmake

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
endif()
5757
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs})
5858

59-
onnxruntime_add_shared_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
59+
onnxruntime_add_shared_library_module(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
6060
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
6161
${REPO_ROOT}/include/onnxruntime/core/session
6262
onnxruntime_common
@@ -119,6 +119,12 @@
119119
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
120120
message(FATAL_ERROR "WebGPU EP shared library build is not supported on Emscripten. Please use static library build.")
121121
endif()
122+
123+
# Configure precompiled headers for shared library build
124+
# PCH ensures ep/adapters.h is included first and improves compilation speed
125+
target_precompile_headers(onnxruntime_providers_webgpu PRIVATE
126+
"${REPO_ROOT}/include/onnxruntime/ep/adapters.h"
127+
)
122128
endif()
123129
124130
set_target_properties(onnxruntime_providers_webgpu PROPERTIES CXX_STANDARD_REQUIRED ON)

cmake/onnxruntime_unittests.cmake

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,18 @@ function(onnxruntime_apply_test_target_workarounds target)
10421042
endif()
10431043
endfunction()
10441044

1045+
# Set environment variables for plugin EP tests when run via CTest.
1046+
function(onnxruntime_set_plugin_ep_test_environment target)
1047+
if(onnxruntime_USE_WEBGPU AND onnxruntime_USE_EP_API_ADAPTERS)
1048+
set(ORT_PLUGIN_EP_JSON_CONFIG "{\"ep_library_registration_name\": \"WebGPU_PluginEP\", \"ep_library_path\": \"$<TARGET_FILE_NAME:onnxruntime_providers_webgpu>\", \"selected_ep_name\": \"WebGpuExecutionProvider\"}")
1049+
set_tests_properties(${target} PROPERTIES
1050+
ENVIRONMENT "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON=${ORT_PLUGIN_EP_JSON_CONFIG}"
1051+
)
1052+
# TODO: add for other plugin EPs if needed
1053+
# elseif()
1054+
endif()
1055+
endfunction()
1056+
10451057
function(onnxruntime_apply_emscripten_test_link_settings target)
10461058
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
10471059
set_target_properties(${target} PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js)
@@ -1250,6 +1262,7 @@ block()
12501262
)
12511263

12521264
onnxruntime_apply_test_target_workarounds(onnxruntime_provider_test)
1265+
onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test)
12531266

12541267
# Expose QNN SDK headers to unit tests via an interface target
12551268
if(onnxruntime_USE_QNN)

docs/OperatorKernels.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ Do not modify directly.*
103103
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
104104
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
105105
|||[17, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
106+
|DeformConv|*in* X:**T**<br> *in* W:**T**<br> *in* offset:**T**<br> *in* B:**T**<br> *in* mask:**T**<br> *out* Y:**T**|22+|**T** = tensor(double), tensor(float)|
107+
|||[19, 21]|**T** = tensor(double), tensor(float)|
106108
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)|
107109
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)|
108110
|||[1, 10]|**T** = tensor(double), tensor(float)|
@@ -697,6 +699,8 @@ Do not modify directly.*
697699
|Crop|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
698700
|CumSum|*in* x:**T**<br> *in* axis:**T2**<br> *out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(int64)|
699701
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(int64)|
702+
|DeformConv|*in* X:**T**<br> *in* W:**T**<br> *in* offset:**T**<br> *in* B:**T**<br> *in* mask:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
703+
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
700704
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
701705
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
702706
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@@ -843,7 +847,12 @@ Do not modify directly.*
843847
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)|
844848
|||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)|
845849
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
846-
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
850+
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|25+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
851+
|||24|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
852+
|||23|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
853+
|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
854+
|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
855+
|||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
847856
|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
848857
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
849858
|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@@ -902,7 +911,9 @@ Do not modify directly.*
902911
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
903912
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
904913
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
905-
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
914+
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|22+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(int64)|
915+
|||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(int64)|
916+
|||[10, 15]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
906917
|RotaryEmbedding|*in* X:**T**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**M**<br> *out* Y:**T**|23+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
907918
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
908919
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@ static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_k
391391
// If not provided, default is 4.
392392
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
393393

394+
// Block size used when converting per-tensor or per-axis DQ + MatMul to MatMulNBits.
395+
// Only applies to DQ nodes without an existing block_size attribute (i.e., per-tensor or per-axis quantization).
396+
// Positive value: explicit block_size (must be power-of-2 and >= 16, e.g., 16, 32, 64, 128).
397+
// "0" or not provided: use default block_size of 32.
398+
// "-1": heuristic - largest power-of-2 <= min(K, 256) that minimizes padding.
399+
static const char* const kOrtSessionOptionsQDQMatMulNBitsBlockSize = "session.qdq_matmulnbits_block_size";
400+
394401
// Enable the DQ->MatMulNBits fusion graph transformer.
395402
// "0": disabled (default). "1": enabled.
396403
// This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered.

include/onnxruntime/ep/adapter/allocator.h

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,50 @@ namespace adapter {
1818
/// </summary>
1919
class Allocator : public OrtAllocator {
2020
public:
21+
/**
22+
* Create from an existing AllocatorPtr.
23+
*/
2124
explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl)
22-
: OrtAllocator{}, memory_info_(memory_info), impl_(impl) {
25+
: Allocator{memory_info} {
26+
ORT_ENFORCE(impl != nullptr, "Allocator implementation cannot be null.");
27+
impl_ = impl;
28+
}
29+
30+
using AllocatorFactory = AllocatorPtr (*)(const OrtMemoryInfo& memory_info);
31+
32+
/**
33+
* Create from an AllocatorFactory, which will be called lazily when the first allocation is made.
34+
*/
35+
explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorFactory get_allocator_impl)
36+
: Allocator{memory_info} {
37+
get_allocator_impl_ = get_allocator_impl;
38+
}
39+
40+
private:
41+
explicit Allocator(const OrtMemoryInfo* memory_info)
42+
: OrtAllocator{}, memory_info_(memory_info) {
2343
version = ORT_API_VERSION;
2444
Alloc = AllocImpl;
2545
Free = FreeImpl;
2646
Info = InfoImpl;
2747
}
48+
AllocatorPtr GetImpl() {
49+
if (!impl_) {
50+
std::call_once(init_flag_, [this]() {
51+
impl_ = get_allocator_impl_(*memory_info_);
52+
});
53+
}
54+
return impl_;
55+
}
2856

29-
private:
3057
static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
3158
auto* allocator = static_cast<Allocator*>(this_ptr);
32-
return allocator->impl_->Alloc(size);
59+
return allocator->GetImpl()->Alloc(size);
3360
}
3461

3562
static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
3663
auto* allocator = static_cast<Allocator*>(this_ptr);
37-
allocator->impl_->Free(p);
64+
allocator->GetImpl()->Free(p);
3865
}
3966

4067
static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
@@ -44,6 +71,8 @@ class Allocator : public OrtAllocator {
4471

4572
const OrtMemoryInfo* memory_info_;
4673
AllocatorPtr impl_;
74+
AllocatorFactory get_allocator_impl_;
75+
std::once_flag init_flag_;
4776
};
4877

4978
} // namespace adapter

include/onnxruntime/ep/adapter/ep.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class Ep : public OrtEp {
2727
profiler_{impl_->GetProfiler()},
2828
temp_space_cpu_allocator_{temp_space_cpu_allocator},
2929
temp_space_allocator_{temp_space_allocator} {
30+
ort_version_supported = ORT_API_VERSION;
3031
}
3132

3233
public:

include/onnxruntime/ep/adapter/op_kernel_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "tensor_helper.h"
2020

2121
namespace onnxruntime {
22-
struct DataTransferManager;
22+
class DataTransferManager;
2323
struct IExecutionProvider;
2424
} // namespace onnxruntime
2525

include/onnxruntime/ep/common.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,24 @@
4141
OrtStatus* _status = (status_expr); \
4242
Ort::Status _ignored{_status}; \
4343
} while (false)
44+
45+
// Helper macros to convert exceptions to OrtStatus* return values.
46+
// Usage:
47+
// EXCEPTION_TO_RETURNED_STATUS_BEGIN
48+
// ... code that may throw ...
49+
// EXCEPTION_TO_RETURNED_STATUS_END
50+
#define EXCEPTION_TO_RETURNED_STATUS_BEGIN try {
51+
#define EXCEPTION_TO_RETURNED_STATUS_END \
52+
} \
53+
catch (const Ort::Exception& ex) { \
54+
Ort::Status status(ex); \
55+
return status.release(); \
56+
} \
57+
catch (const std::exception& ex) { \
58+
Ort::Status status(ex.what(), ORT_EP_FAIL); \
59+
return status.release(); \
60+
} \
61+
catch (...) { \
62+
Ort::Status status("Unknown exception", ORT_EP_FAIL); \
63+
return status.release(); \
64+
}

0 commit comments

Comments
 (0)