Skip to content

Commit 978efc8

Browse files
authored
[Plugin TRT EP] Add MemcpyToHost and MemcpyFromHost kernel implementations (#557)
* update
* refactor MemcpyKernelBase
* add define ORT_API_MANUAL_INIT
* address reviewer's comments
* address reviewer's comments
1 parent bd104a0 commit 978efc8

15 files changed

Lines changed: 506 additions & 24 deletions

plugin_execution_providers/tensorrt/CMakeLists.txt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
add_definitions(-DONNX_NAMESPACE=onnx)
add_definitions(-DONNX_ML)
# Keep Windows.h from defining min/max macros that clash with std::min/std::max.
add_definitions(-DNOMINMAX)

# Collect EP sources, including the new kernels/ subdirectory. Headers are
# globbed as well so they appear in IDE-generated projects.
# NOTE(review): file(GLOB) is evaluated at configure time only — adding a new
# source file requires re-running CMake (consider CONFIGURE_DEPENDS or an
# explicit source list).
file(GLOB tensorrt_src
    "./src/*.cc"
    "./src/kernels/*.cc"
    "./src/utils/*.cc"
    "./src/cuda/unary_elementwise_ops_impl.cu"
    "./src/*.h"
    "./src/kernels/*.h"
)

add_library(TensorRTEp SHARED ${tensorrt_src})
3543

3644
set_onnxruntime_paths(
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "utils.h"
5+
#include "memcpy.h"
6+
#include <cuda_runtime.h>
7+
8+
namespace trt_ep {
9+
10+
template <typename T>
11+
OrtStatus* MemcpyKernelBase::CreateImpl(const OrtKernelInfo* info, void* state,
12+
/*out*/ OrtKernelImpl*& kernel) noexcept {
13+
try {
14+
auto p = std::make_unique<T>(info, state, typename T::PrivateTag{});
15+
kernel = p.release();
16+
return nullptr;
17+
} catch (const Ort::Exception& ex) {
18+
Ort::Status status(ex);
19+
return status.release();
20+
} catch (const std::exception& ex) {
21+
Ort::Status status(ex.what(), ORT_EP_FAIL);
22+
return status.release();
23+
} catch (...) {
24+
Ort::Status status("Unknown exception in MemcpyKernelBase::Create", ORT_EP_FAIL);
25+
return status.release();
26+
}
27+
}
28+
29+
template <typename T>
30+
static void MemcpyKernelBase::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
31+
delete static_cast<T*>(this_ptr);
32+
}
33+
34+
// Compute for MemcpyFromHost: asynchronously copies input 0 from host (CPU)
// memory to device (GPU) memory on the kernel context's CUDA stream.
//
// this_ptr:   the MemcpyFromHost instance (unused — the copy needs no state).
// kernel_ctx: supplies the input/output tensors and the GPU compute stream.
// Returns nullptr on success, or an OrtStatus describing the failure.
//
// NOTE: the copy is asynchronous; stream synchronization before the output is
// consumed is assumed to be handled by ORT's stream machinery — confirm.
//
// FIXES vs. original: removed unused element_count/element_type queries
// (dead code and extra failure points; MemcpyToHost never had them);
// tensor_info is now released on every path instead of leaking when a shape
// query fails and RETURN_IF_ERROR returns early; added catch(...) so a foreign
// exception cannot escape this noexcept function (matches CreateImpl).
OrtStatus* MemcpyFromHost::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
  try {
    const OrtApi& ort_api = Ort::GetApi();
    const OrtValue* input_tensor = nullptr;
    RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_ctx, 0, &input_tensor));

    // Get the input shape. tensor_info must be released before any early
    // return, so collect the status and release unconditionally.
    OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
    RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input_tensor, &tensor_info));

    size_t num_dims = 0;
    std::vector<int64_t> dims;
    OrtStatus* shape_status = ort_api.GetDimensionsCount(tensor_info, &num_dims);
    if (shape_status == nullptr) {
      dims.resize(num_dims);
      shape_status = ort_api.GetDimensions(tensor_info, dims.data(), num_dims);
    }
    ort_api.ReleaseTensorTypeAndShapeInfo(tensor_info);
    if (shape_status != nullptr) {
      return shape_status;
    }

    // Get the output tensor with the same shape as the input.
    OrtValue* output_tensor = nullptr;
    RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_ctx, 0, dims.data(), num_dims, &output_tensor));

    // Get data pointers.
    const void* input_data = nullptr;
    void* output_data = nullptr;
    RETURN_IF_ERROR(ort_api.GetTensorData(input_tensor, &input_data));
    RETURN_IF_ERROR(ort_api.GetTensorMutableData(output_tensor, &output_data));

    // Calculate size in bytes.
    size_t bytes = 0;
    RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(input_tensor, &bytes));

    // Get CUDA stream from kernel context.
    void* cuda_stream = nullptr;
    RETURN_IF_ERROR(ort_api.KernelContext_GetGPUComputeStream(kernel_ctx, &cuda_stream));
    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);

    // Copy from host (CPU) to device (GPU) asynchronously.
    cudaError_t cuda_err = cudaMemcpyAsync(output_data, input_data, bytes, cudaMemcpyHostToDevice, stream);
    if (cuda_err != cudaSuccess) {
      return ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err));
    }

    return nullptr;
  } catch (const Ort::Exception& ex) {
    Ort::Status status(ex);
    return status.release();
  } catch (const std::exception& ex) {
    Ort::Status status(ex.what(), ORT_EP_FAIL);
    return status.release();
  } catch (...) {
    // noexcept: nothing may escape. Mirror CreateImpl's catch-all.
    Ort::Status status("Unknown exception in MemcpyFromHost::ComputeImpl", ORT_EP_FAIL);
    return status.release();
  }
}
91+
92+
// Compute for MemcpyToHost: asynchronously copies input 0 from device (GPU)
// memory to host (CPU) memory on the kernel context's CUDA stream.
//
// this_ptr:   the MemcpyToHost instance (unused — the copy needs no state).
// kernel_ctx: supplies the input/output tensors and the GPU compute stream.
// Returns nullptr on success, or an OrtStatus describing the failure.
//
// NOTE: the copy is asynchronous; stream synchronization before the host
// output is read is assumed to be handled by ORT's stream machinery — confirm.
//
// FIXES vs. original: tensor_info is now released on every path instead of
// leaking when a shape query fails and RETURN_IF_ERROR returns early; added
// catch(...) so a foreign exception cannot escape this noexcept function
// (matches CreateImpl).
OrtStatus* MemcpyToHost::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
  try {
    const OrtApi& ort_api = Ort::GetApi();
    const OrtValue* input_tensor = nullptr;
    RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_ctx, 0, &input_tensor));

    // Get the input shape. tensor_info must be released before any early
    // return, so collect the status and release unconditionally.
    OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
    RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input_tensor, &tensor_info));

    size_t num_dims = 0;
    std::vector<int64_t> dims;
    OrtStatus* shape_status = ort_api.GetDimensionsCount(tensor_info, &num_dims);
    if (shape_status == nullptr) {
      dims.resize(num_dims);
      shape_status = ort_api.GetDimensions(tensor_info, dims.data(), num_dims);
    }
    ort_api.ReleaseTensorTypeAndShapeInfo(tensor_info);
    if (shape_status != nullptr) {
      return shape_status;
    }

    // Get the output tensor with the same shape as the input.
    OrtValue* output_tensor = nullptr;
    RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_ctx, 0, dims.data(), num_dims, &output_tensor));

    // Get data pointers.
    const void* input_data = nullptr;
    void* output_data = nullptr;
    RETURN_IF_ERROR(ort_api.GetTensorData(input_tensor, &input_data));
    RETURN_IF_ERROR(ort_api.GetTensorMutableData(output_tensor, &output_data));

    // Calculate size in bytes.
    size_t bytes = 0;
    RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(input_tensor, &bytes));

    // Get CUDA stream from kernel context.
    void* cuda_stream = nullptr;
    RETURN_IF_ERROR(ort_api.KernelContext_GetGPUComputeStream(kernel_ctx, &cuda_stream));
    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);

    // Copy from device (GPU) to host (CPU) asynchronously.
    cudaError_t cuda_err = cudaMemcpyAsync(output_data, input_data, bytes, cudaMemcpyDeviceToHost, stream);
    if (cuda_err != cudaSuccess) {
      return ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err));
    }

    return nullptr;
  } catch (const Ort::Exception& ex) {
    Ort::Status status(ex);
    return status.release();
  } catch (const std::exception& ex) {
    Ort::Status status(ex.what(), ORT_EP_FAIL);
    return status.release();
  } catch (...) {
    // noexcept: nothing may escape. Mirror CreateImpl's catch-all.
    Ort::Status status("Unknown exception in MemcpyToHost::ComputeImpl", ORT_EP_FAIL);
    return status.release();
  }
}
143+
144+
// Register MemcpyFromHost (ONNX domain "", single version 1): input 0 is pinned
// to CPU memory via SetInputMemType; the kernel copies it to device memory.
ONNX_OPERATOR_KERNEL_EX(
    MemcpyFromHost,
    kOnnxDomain,
    /*version*/ 1,  // start_version == end_version == 1 (inclusive).
                    // NOTE(review): the original comment claimed 14 — likely a
                    // copy/paste remnant; confirm the intended opset version.
    (Ort::KernelDefBuilder()
         .SetInputMemType(0, OrtMemType::OrtMemTypeCPUInput)
         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
    MemcpyFromHost)

// Register MemcpyToHost (ONNX domain "", single version 1): output 0 is pinned
// to CPU memory via SetOutputMemType; the kernel copies the device input to it.
ONNX_OPERATOR_KERNEL_EX(
    MemcpyToHost,
    kOnnxDomain,
    /*version*/ 1,  // start_version == end_version == 1 (inclusive).
                    // NOTE(review): the original comment claimed 14 — likely a
                    // copy/paste remnant; confirm the intended opset version.
    (Ort::KernelDefBuilder()
         .SetOutputMemType(0, OrtMemType::OrtMemTypeCPUOutput)
         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
    MemcpyToHost)

}  // namespace trt_ep
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
struct OrtKernelImpl;
5+
struct OrtKernelInfo;
6+
struct OrtKernelContext;
7+
struct OrtStatus;
8+
9+
namespace trt_ep {
10+
11+
struct MemcpyKernelBase : public OrtKernelImpl {
12+
// Base class for MemcpyFromHost and MemcpyToHost to share common code.
13+
protected:
14+
MemcpyKernelBase(const OrtKernelInfo* info, void* state) : OrtKernelImpl {}, info_(info), state_(state) {}
15+
16+
template <typename T>
17+
static OrtStatus* CreateImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept;
18+
19+
template <typename T>
20+
static void ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;
21+
22+
const OrtKernelInfo* info_;
23+
void* state_; // Custom state passed from OrtEp
24+
};
25+
26+
// Kernel that copies its single input tensor from host (CPU) memory to device
// memory. Instances are created via the static Create() factory; construction
// wires the OrtKernelImpl Compute/Release function pointers to the static
// implementations below.
struct MemcpyFromHost : public MemcpyKernelBase {
 private:
  struct PrivateTag {};  // Used to prevent use of public constructor (use static MemcpyFromHost::Create())
                         // Need to make the constructor public for std::make_unique().

  // Allow base template helper to access PrivateTag
  friend struct MemcpyKernelBase;

 public:
  // Public only so std::make_unique can call it; PrivateTag keeps outside
  // callers from constructing directly.
  MemcpyFromHost(const OrtKernelInfo* info, void* state, PrivateTag) : MemcpyKernelBase(info, state) {
    ort_version_supported = ORT_API_VERSION;
    Compute = ComputeImpl;
    Release = ReleaseImpl;
  };

  // Factory: builds a MemcpyFromHost and returns ownership through `kernel`.
  // Returns nullptr on success, an OrtStatus on failure.
  static OrtStatus* Create(const OrtKernelInfo* info, void* state,
                           /*out*/ OrtKernelImpl*& kernel) noexcept {
    return CreateImpl<MemcpyFromHost>(info, state, kernel);
  }

  // Host-to-device copy; defined in memcpy.cc.
  static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;

  // Destroys the kernel via the typed base helper.
  static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
    MemcpyKernelBase::ReleaseImpl<MemcpyFromHost>(this_ptr);
  };
};
52+
53+
// Kernel that copies its single input tensor from device memory to host (CPU)
// memory. Mirrors MemcpyFromHost: Create() factory + function-pointer wiring
// in the constructor.
struct MemcpyToHost : public MemcpyKernelBase {
 private:
  struct PrivateTag {};  // Used to prevent use of public constructor (use static MemcpyToHost::Create())
                         // Need to make the constructor public for std::make_unique().

  // Allow base template helper to access PrivateTag
  friend struct MemcpyKernelBase;

 public:
  // Public only so std::make_unique can call it; PrivateTag keeps outside
  // callers from constructing directly.
  MemcpyToHost(const OrtKernelInfo* info, void* state, PrivateTag) : MemcpyKernelBase(info, state) {
    ort_version_supported = ORT_API_VERSION;
    Compute = ComputeImpl;
    Release = ReleaseImpl;
  };

  // Factory: builds a MemcpyToHost and returns ownership through `kernel`.
  // Returns nullptr on success, an OrtStatus on failure.
  static OrtStatus* Create(const OrtKernelInfo* info, void* state,
                           /*out*/ OrtKernelImpl*& kernel) noexcept {
    return CreateImpl<MemcpyToHost>(info, state, kernel);
  }

  // Device-to-host copy; defined in memcpy.cc.
  static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;

  // Destroys the kernel via the typed base helper.
  static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
    MemcpyKernelBase::ReleaseImpl<MemcpyToHost>(this_ptr);
  };
};

}  // namespace trt_ep
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#pragma once

#include "ep_utils.h"

namespace trt_ep {

/// <summary>
/// Gets an OrtDataType for a tensor type. Throws on error.
/// </summary>
/// <param name="elem_type">ONNX tensor element data type to look up.</param>
/// <returns>The corresponding OrtDataType pointer. Presumably non-owning and
/// managed by ORT (no release call is made) — confirm against the EP API.</returns>
// [[nodiscard]] added: calling this purely for its side effects is always a
// mistake — the looked-up type is the whole point.
[[nodiscard]] inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) {
  const OrtEpApi& ep_api = Ort::GetEpApi();
  const OrtDataType* result = nullptr;

  Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result));
  return result;
}
19+
20+
/// <summary>
21+
/// Contains information to create a kernel: kernel definition, creation function + state.
22+
/// </summary>
23+
/// <summary>
/// Contains information to create a kernel: kernel definition, creation function + state.
/// </summary>
struct KernelCreateInfo {
  KernelCreateInfo() = default;
  // def:   kernel definition (op type, domain, version range, constraints).
  // func:  C callback that instantiates the kernel at session setup.
  // state: opaque pointer handed back to func when invoked (not owned here).
  KernelCreateInfo(Ort::KernelDef def, OrtKernelCreateFunc func, void* state)
      : kernel_def{std::move(def)}, kernel_create_func{func}, kernel_create_func_state{state} {}

  Ort::KernelDef kernel_def{nullptr};                // Owning wrapper; nullptr when empty
  OrtKernelCreateFunc kernel_create_func = nullptr;  // Factory callback
  void* kernel_create_func_state = nullptr;          // State passed to the callback (not owned)
};
32+
33+
// Signature of the per-kernel factory functions that the
// ONNX_OPERATOR_(VERSIONED_)KERNEL_EX macros below generate.
using BuildKernelCreateInfoFn = OrtStatus* (*)(const char*, void*, KernelCreateInfo*);

// Primary template: each registered kernel supplies an explicit specialization
// (generated by the macros below). Intentionally left undefined.
template <typename T>
OrtStatus* BuildKernelCreateInfo(const char* ep_name, void* create_func_state, /*out*/ KernelCreateInfo* result);

// <void> specialization yields an empty KernelCreateInfo — presumably used as
// a sentinel entry in kernel registration tables; verify against the caller.
template <>
inline OrtStatus* BuildKernelCreateInfo<void>(const char* /*ep_name*/, void* /*create_func_state*/,
                                              /*out*/ KernelCreateInfo* result) {
  result->kernel_def = Ort::KernelDef{nullptr};
  result->kernel_create_func = nullptr;
  result->kernel_create_func_state = nullptr;
  return nullptr;
}
46+
47+
// The default ONNX operator domain is the empty string.
// FIX: `inline constexpr` (C++17, which ONNX Runtime requires) gives a single
// entity across all TUs that include this header; the previous
// `static constexpr` created a separate per-TU copy.
inline constexpr const char* kOnnxDomain = "";

// Naming convention for operator kernel classes with a start and end version range.
// NOTE(review): the `example_ep_` prefix looks copied from the example EP;
// confirm whether a trt_ep-specific prefix was intended (renaming is safe only
// if done everywhere this macro's generated names are referenced).
#define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name) \
  example_ep_##name##_##domain##_ver##startver##_##endver

// Naming convention for operator kernel classes for a single version
#define ONNX_OPERATOR_KERNEL_CLASS_NAME(domain, version, name) \
  ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, version, version, name)

// Defines a function of type BuildKernelCreateInfoFn for a kernel implementation
// with a start and end version range: declares a unique tag class, then
// specializes BuildKernelCreateInfo for it. The specialization builds the
// Ort::KernelDef and packages a lambda that forwards to kernel_class::Create.
#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, builder, kernel_class)    \
  class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name);                  \
  template <>                                                                                       \
  OrtStatus*                                                                                        \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name)>( \
      const char* ep_name,                                                                          \
      void* create_kernel_state,                                                                    \
      KernelCreateInfo* result) {                                                                   \
    try {                                                                                           \
      Ort::KernelDef kernel_def = builder.SetOperatorType(#name)                                    \
                                      .SetDomain(domain)                                            \
                                      .SetSinceVersion(startver, endver)                            \
                                      .SetExecutionProvider(ep_name)                                \
                                      .Build();                                                     \
                                                                                                    \
      auto kernel_create_func = [](void* state, const OrtKernelInfo* info,                          \
                                   OrtKernelImpl** kernel_out) noexcept -> OrtStatus* {             \
        RETURN_IF(kernel_out == nullptr,                                                            \
                  "OrtKernelCreateFunc received a NULL kernel_out argument");                       \
                                                                                                    \
        *kernel_out = nullptr;                                                                      \
        RETURN_IF_ERROR(kernel_class::Create(info, state, *kernel_out));                            \
        return nullptr;                                                                             \
      };                                                                                            \
                                                                                                    \
      *result = KernelCreateInfo(std::move(kernel_def), kernel_create_func, create_kernel_state);   \
    } catch (const Ort::Exception& ex) {                                                            \
      Ort::Status status(ex);                                                                       \
      return status.release();                                                                      \
    } catch (const std::exception& ex) {                                                            \
      Ort::Status status(ex.what(), ORT_EP_FAIL);                                                   \
      return status.release();                                                                      \
    }                                                                                               \
    return nullptr;                                                                                 \
  }

// Defines a function of type BuildKernelCreateInfoFn for a kernel implementation with a start version.
#define ONNX_OPERATOR_KERNEL_EX(name, domain, version, builder, kernel_class) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, version, version, builder, kernel_class)

}  // namespace trt_ep

0 commit comments

Comments
 (0)