1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
#include "memcpy.h"
#include "utils.h"

#include <cuda_runtime.h>
7+
8+ namespace trt_ep {
9+
10+ template <typename T>
11+ OrtStatus* MemcpyKernelBase::CreateImpl (const OrtKernelInfo* info, void * state,
12+ /* out*/ OrtKernelImpl*& kernel) noexcept {
13+ try {
14+ auto p = std::make_unique<T>(info, state, typename T::PrivateTag{});
15+ kernel = p.release ();
16+ return nullptr ;
17+ } catch (const Ort::Exception& ex) {
18+ Ort::Status status (ex);
19+ return status.release ();
20+ } catch (const std::exception& ex) {
21+ Ort::Status status (ex.what (), ORT_EP_FAIL);
22+ return status.release ();
23+ } catch (...) {
24+ Ort::Status status (" Unknown exception in MemcpyKernelBase::Create" , ORT_EP_FAIL);
25+ return status.release ();
26+ }
27+ }
28+
29+ template <typename T>
30+ static void MemcpyKernelBase::ReleaseImpl (OrtKernelImpl* this_ptr) noexcept {
31+ delete static_cast <T*>(this_ptr);
32+ }
33+
34+ OrtStatus* MemcpyFromHost::ComputeImpl (OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
35+ try {
36+ const OrtApi& ort_api = Ort::GetApi ();
37+ const OrtValue* input_tensor = nullptr ;
38+ RETURN_IF_ERROR (ort_api.KernelContext_GetInput (kernel_ctx, 0 , &input_tensor));
39+
40+ // Get tensor shape and type
41+ OrtTensorTypeAndShapeInfo* tensor_info = nullptr ;
42+ RETURN_IF_ERROR (ort_api.GetTensorTypeAndShape (input_tensor, &tensor_info));
43+
44+ size_t element_count = 0 ;
45+ RETURN_IF_ERROR (ort_api.GetTensorShapeElementCount (tensor_info, &element_count));
46+
47+ ONNXTensorElementDataType element_type;
48+ RETURN_IF_ERROR (ort_api.GetTensorElementType (tensor_info, &element_type));
49+
50+ size_t num_dims = 0 ;
51+ RETURN_IF_ERROR (ort_api.GetDimensionsCount (tensor_info, &num_dims));
52+
53+ std::vector<int64_t > dims (num_dims);
54+ RETURN_IF_ERROR (ort_api.GetDimensions (tensor_info, dims.data (), num_dims));
55+ ort_api.ReleaseTensorTypeAndShapeInfo (tensor_info);
56+
57+ // Get output tensor
58+ OrtValue* output_tensor = nullptr ;
59+ RETURN_IF_ERROR (ort_api.KernelContext_GetOutput (kernel_ctx, 0 , dims.data (), num_dims, &output_tensor));
60+
61+ // Get data pointers
62+ const void * input_data = nullptr ;
63+ void * output_data = nullptr ;
64+ RETURN_IF_ERROR (ort_api.GetTensorData (input_tensor, &input_data));
65+ RETURN_IF_ERROR (ort_api.GetTensorMutableData (output_tensor, &output_data));
66+
67+ // Calculate size in bytes
68+ size_t bytes = 0 ;
69+ RETURN_IF_ERROR (ort_api.GetTensorSizeInBytes (input_tensor, &bytes));
70+
71+ // Get CUDA stream from kernel context
72+ void * cuda_stream = nullptr ;
73+ RETURN_IF_ERROR (ort_api.KernelContext_GetGPUComputeStream (kernel_ctx, &cuda_stream));
74+ cudaStream_t stream = static_cast <cudaStream_t>(cuda_stream);
75+
76+ // Copy from host (CPU) to device (GPU) asynchronously
77+ cudaError_t cuda_err = cudaMemcpyAsync (output_data, input_data, bytes, cudaMemcpyHostToDevice, stream);
78+ if (cuda_err != cudaSuccess) {
79+ return ort_api.CreateStatus (ORT_EP_FAIL, cudaGetErrorString (cuda_err));
80+ }
81+
82+ return nullptr ;
83+ } catch (const Ort::Exception& ex) {
84+ Ort::Status status (ex);
85+ return status.release ();
86+ } catch (const std::exception& ex) {
87+ Ort::Status status (ex.what (), ORT_EP_FAIL);
88+ return status.release ();
89+ }
90+ }
91+
92+ OrtStatus* MemcpyToHost::ComputeImpl (OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
93+ try {
94+ const OrtApi& ort_api = Ort::GetApi ();
95+ const OrtValue* input_tensor = nullptr ;
96+ RETURN_IF_ERROR (ort_api.KernelContext_GetInput (kernel_ctx, 0 , &input_tensor));
97+
98+ // Get tensor shape and type
99+ OrtTensorTypeAndShapeInfo* tensor_info = nullptr ;
100+ RETURN_IF_ERROR (ort_api.GetTensorTypeAndShape (input_tensor, &tensor_info));
101+
102+ size_t num_dims = 0 ;
103+ RETURN_IF_ERROR (ort_api.GetDimensionsCount (tensor_info, &num_dims));
104+
105+ std::vector<int64_t > dims (num_dims);
106+ RETURN_IF_ERROR (ort_api.GetDimensions (tensor_info, dims.data (), num_dims));
107+ ort_api.ReleaseTensorTypeAndShapeInfo (tensor_info);
108+
109+ // Get output tensor
110+ OrtValue* output_tensor = nullptr ;
111+ RETURN_IF_ERROR (ort_api.KernelContext_GetOutput (kernel_ctx, 0 , dims.data (), num_dims, &output_tensor));
112+
113+ // Get data pointers
114+ const void * input_data = nullptr ;
115+ void * output_data = nullptr ;
116+ RETURN_IF_ERROR (ort_api.GetTensorData (input_tensor, &input_data));
117+ RETURN_IF_ERROR (ort_api.GetTensorMutableData (output_tensor, &output_data));
118+
119+ // Calculate size in bytes
120+ size_t bytes = 0 ;
121+ RETURN_IF_ERROR (ort_api.GetTensorSizeInBytes (input_tensor, &bytes));
122+
123+ // Get CUDA stream from kernel context
124+ void * cuda_stream = nullptr ;
125+ RETURN_IF_ERROR (ort_api.KernelContext_GetGPUComputeStream (kernel_ctx, &cuda_stream));
126+ cudaStream_t stream = static_cast <cudaStream_t>(cuda_stream);
127+
128+ // Copy from device (GPU) to host (CPU) asynchronously
129+ cudaError_t cuda_err = cudaMemcpyAsync (output_data, input_data, bytes, cudaMemcpyDeviceToHost, stream);
130+ if (cuda_err != cudaSuccess) {
131+ return ort_api.CreateStatus (ORT_EP_FAIL, cudaGetErrorString (cuda_err));
132+ }
133+
134+ return nullptr ;
135+ } catch (const Ort::Exception& ex) {
136+ Ort::Status status (ex);
137+ return status.release ();
138+ } catch (const std::exception& ex) {
139+ Ort::Status status (ex.what (), ORT_EP_FAIL);
140+ return status.release ();
141+ }
142+ }
143+
// Kernel registration for MemcpyFromHost in the ONNX domain: input 0 is
// declared CPU-resident (OrtMemTypeCPUInput) and type constraint "T" is
// limited to float tensors.
// NOTE(review): the version value below is 1, but the original comment
// claimed "start_version: 14, end_version: 14" — confirm the intended
// opset range against the op schema.
ONNX_OPERATOR_KERNEL_EX(
    MemcpyFromHost,
    kOnnxDomain,
    /*version*/ 1,
    (Ort::KernelDefBuilder()
         .SetInputMemType(0, OrtMemType::OrtMemTypeCPUInput)
         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
    MemcpyFromHost)
152+
// Kernel registration for MemcpyToHost in the ONNX domain: output 0 is
// declared CPU-resident (OrtMemTypeCPUOutput) and type constraint "T" is
// limited to float tensors.
// NOTE(review): the version value below is 1, but the original comment
// claimed "start_version: 14, end_version: 14" — confirm the intended
// opset range against the op schema.
ONNX_OPERATOR_KERNEL_EX(
    MemcpyToHost,
    kOnnxDomain,
    /*version*/ 1,
    (Ort::KernelDefBuilder()
         .SetOutputMemType(0, OrtMemType::OrtMemTypeCPUOutput)
         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
    MemcpyToHost)
161+
162+ } // namespace trt_ep