Skip to content

Commit fa01628

Browse files
committed
fix bugs
1 parent cd61e54 commit fa01628

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ def run(
313313
)
314314
timing_cache = self._create_timing_cache(builder_config, existing_cache)
315315

316-
engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
317-
assert engine
316+
serialized_engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
317+
assert serialized_engine
318318

319319
serialized_cache = (
320320
bytearray(timing_cache.serialize())
@@ -324,10 +324,10 @@ def run(
324324
_LOGGER.info(
325325
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
326326
)
327-
_LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
327+
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
328328

329329
return TRTInterpreterResult(
330-
engine, self._input_names, self._output_names, serialized_cache
330+
serialized_engine, self._input_names, self._output_names, serialized_cache
331331
)
332332

333333
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
2929

3030
def __init__(
3131
self,
32-
engine: trt.ICudaEngine,
32+
engine: trt.tensorrt.IHostMemory,
3333
input_names: Optional[List[str]] = None,
3434
output_names: Optional[List[str]] = None,
3535
target_device: Device = Device._current_device(),
@@ -60,9 +60,7 @@ def _initialize(self) -> None:
6060
self.engine = runtime.deserialize_cuda_engine(self.engine)
6161
self.context = self.engine.create_execution_context()
6262

63-
assert (
64-
self.engine.num_io_tensors // self.engine.num_optimization_profiles
65-
) == (len(self.input_names) + len(self.output_names))
63+
assert self.engine.num_io_tensors == (len(self.input_names) + len(self.output_names))
6664

6765
self.input_dtypes = [
6866
dtype._from(self.engine.get_tensor_dtype(input_name))

0 commit comments

Comments
 (0)