@@ -174,6 +174,7 @@ def __init__(
174174 self .cudagraph : Optional [torch .cuda .CUDAGraph ] = None
175175 self ._caller_stream : Optional [torch .cuda .Stream ] = None
176176 self ._engine_stream : Optional [torch .cuda .Stream ] = None
177+ self .output_tensors : Optional [List [torch .Tensor ]] = None
177178
178179 # TODO: Make the below a Dictionary {shape: cudagraph}
179180 self .shape_key : Optional [str ] = None
@@ -218,10 +219,27 @@ def __init__(
218219 self .requires_output_allocator = requires_output_allocator
219220 self .output_allocator : Optional [DynamicOutputAllocator ] = None
220221 self .use_output_allocator_outputs = False
221-
222+ self .device = torch .cuda .current_device ()
223+ self .cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
224+ # If the output tensor is not owned by the engine (output_tensors_are_unowned=True), we need to create a new output tensor in each forward pass
225+ self .output_tensors_are_unowned = False
222226 if self .serialized_engine is not None and not self .settings .lazy_engine_init :
223227 self .setup_engine ()
224228
229+ def set_output_tensors_as_unowned (self , enabled : bool ) -> None :
230+ """
231+ Flag to set if the output tensors of this engine are solely owned by the Torch-TensorRT Runtime or if they might be shared with a user.
232+ If the tensors are not owned by the runtime, then they must be recreated on every forward call which may have implications for performance.
233+ Typically only the final engine in a graph requires output tensors to be unowned and there are performance gains to be had for intermediate engines to manage their own standing memory.
234+ Therefore this should only be set to True for the final module in a graph and leave false for intermediate modules.
235+
236+ Args:
237+ enabled: bool
238+ Whether to set the flag to True.
239+
240+ """
241+ self .output_tensors_are_unowned = enabled
242+
225243 def get_streamable_device_memory_budget (self ) -> Any :
226244 return self .engine .streamable_weights_size
227245
@@ -288,16 +306,25 @@ def setup_engine(self) -> None:
288306 for output_name in self .output_names
289307 ]
290308 self .output_shapes = [
291- self .engine .get_tensor_shape (output_name )
309+ tuple ( self .context .get_tensor_shape (output_name ) )
292310 for output_name in self .output_names
293311 ]
294312
313+ self .shape_key = "" .join (
314+ str (tuple (t )).replace (" " , "" ) for t in self .input_shapes
315+ )
316+
295317 if self .requires_output_allocator :
296318 self .create_output_allocator ()
297319
298320 if torch_tensorrt .runtime .get_cudagraphs_mode ():
299321 self .cudagraph = torch .cuda .CUDAGraph ()
300322
323+ self .is_shape_inference_io = {
324+ input_name : self .engine .is_shape_inference_io (input_name )
325+ for input_name in self .input_names
326+ }
327+
301328 def _check_initialized (self ) -> None :
302329 if not self .initialized :
303330 raise RuntimeError ("PythonTorchTensorRTModule is not initialized." )
@@ -383,16 +410,17 @@ def setup_input_tensors(
383410
384411 # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
385412 # as per TensorRT requirements
386- if self .engine . is_shape_inference_io ( input_name ) :
413+ if self .is_shape_inference_io [ input_name ] :
387414 # Shape tensor inputs are casted to int64 explicitly
388415 # Currently Torch CPU pointers are not working; numpy pointers are used instead
389416 # to refer to underlying memory
390417 inputs_cpu = contiguous_inputs [i ].cpu ().to (torch .int64 ).numpy ().copy ()
391418 self .context .set_tensor_address (input_name , inputs_cpu .ctypes .data )
392419 else :
393- self .context .set_input_shape (
394- input_name , tuple (contiguous_inputs [i ].shape )
395- )
420+ if need_cudagraphs_record :
421+ self .context .set_input_shape (
422+ input_name , tuple (contiguous_inputs [i ].shape )
423+ )
396424 if cudagraphs_enabled :
397425 self ._input_buffers [i ].copy_ (contiguous_inputs [i ])
398426 self .context .set_tensor_address (
@@ -411,7 +439,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
411439 output = torch .empty (
412440 size = self .output_shapes [o ],
413441 dtype = self .output_dtypes [o ],
414- device = torch . cuda . current_device () ,
442+ device = self . device ,
415443 )
416444 outputs .append (output )
417445 return outputs
@@ -460,7 +488,9 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
460488 ), f"Wrong number of inputs, expect { len (self .input_names )} get { len (contiguous_inputs )} ."
461489
462490 self .setup_input_tensors (
463- contiguous_inputs , self .cudagraphs_enabled , need_cudagraphs_record
491+ contiguous_inputs ,
492+ self .cudagraphs_enabled ,
493+ need_cudagraphs_record ,
464494 )
465495
466496 if shape_changed :
@@ -482,15 +512,22 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
482512 if can_use_pre_allocated_outputs :
483513 outputs = self .pre_allocated_outputs
484514 else :
485- self .output_shapes = [
486- tuple (self .context .get_tensor_shape (output_name ))
487- for output_name in self .output_names
488- ]
515+ if shape_changed or self .output_tensors is None :
516+ self .output_shapes = [
517+ tuple (self .context .get_tensor_shape (output_name ))
518+ for output_name in self .output_names
519+ ]
489520 if DYNAMIC_DIM in self .output_shapes :
490521 raise ValueError (
491522 "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
492523 )
493- outputs = self .create_output_tensors ()
524+ if (
525+ self .output_tensors is None
526+ or self .output_tensors_are_unowned
527+ or shape_changed
528+ ):
529+ self .output_tensors = self .create_output_tensors ()
530+ outputs = self .output_tensors
494531
495532 for o , output_name in enumerate (self .output_names ):
496533 if need_cudagraphs_record :
@@ -751,13 +788,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
751788 # Representation of input shapes to a given model
752789 # Shapes are concatenated as so:
753790 # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
754- tensor_inputs = []
755- for t in inputs :
756- if not isinstance (t , torch .Tensor ):
757- return True
758- tensor_inputs .append (t )
791+ if not all (isinstance (t , torch .Tensor ) for t in inputs ):
792+ return True
793+
759794 new_shape_key = "" .join (
760- str (tuple (t .shape )).replace (" " , "" ) for t in tensor_inputs
795+ str (tuple (t .shape )).replace (" " , "" )
796+ for t in inputs
797+ if isinstance (t , torch .Tensor )
761798 )
762799
763800 # If the new shape key differs from the existing one,
0 commit comments