@@ -96,7 +96,8 @@ void setup_input_tensors(
9696 std::vector<at::Tensor> inputs,
9797 c10::intrusive_ptr<TRTEngine> compiled_engine,
9898 bool cudagraphs_enabled,
99- bool need_cudagraphs_record) {
99+ bool need_cudagraphs_record,
100+ bool shape_changed) {
100101 // this is a buffer to store shape tensor input addresses throughout the runtime scope
101102 std::list<std::vector<int64_t >> inputShapeTensorValues;
102103 std::list<at::Tensor> formatted_inputs (compiled_engine->num_io .first );
@@ -145,7 +146,7 @@ void setup_input_tensors(
145146 // Create a new persistent input buffer
146147 compiled_engine->input_buffers [i] = std::move (formatted_inputs.back ().clone ());
147148 }
148- if (need_cudagraphs_record ) {
149+ if (shape_changed ) {
149150 TORCHTRT_CHECK (
150151 compiled_engine->exec_ctx ->setInputShape (name.c_str (), dims), " Error while setting the input shape" );
151152 }
@@ -226,7 +227,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
226227 input_profiler_guard =
227228 std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path );
228229 }
229- setup_input_tensors (inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
230+ setup_input_tensors (inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, shape_changed );
230231 // Check if input shapes can be inferred.
231232 int32_t const io_size{compiled_engine->io_size };
232233 std::vector<char const *> names (io_size);
@@ -361,7 +362,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
361362 std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path );
362363 }
363364
364- setup_input_tensors (inputs, compiled_engine, false , false );
365+ setup_input_tensors (inputs, compiled_engine, false , false , true );
365366 // Check if input shapes can be inferred.
366367 int32_t const io_size{compiled_engine->cuda_engine ->getNbIOTensors ()};
367368 std::vector<char const *> names (io_size);
0 commit comments