diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index 846c90bdd5..ec53fa928d 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -164,6 +164,9 @@ def run( timing_cache=None, profiling_verbosity=None, tactic_sources=None, + max_aux_streams=None, + version_compatible=False, + optimization_level=None, ) -> TRTInterpreterResult: """ Build TensorRT engine with some configs. @@ -231,6 +234,18 @@ def run( if profiling_verbosity else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) + + if trt.__version__ >= "8.6": + if max_aux_streams is not None: + _LOGGER.info(f"Setting max aux streams to {max_aux_streams}") + builder_config.max_aux_streams = max_aux_streams + if version_compatible: + _LOGGER.info(f"Using version compatible") + builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) + if optimization_level is not None: + _LOGGER.info(f"Using optimization level {optimization_level}") + builder_config.builder_optimization_level = optimization_level + if lower_precision == LowerPrecision.FP16: builder_config.set_flag(trt.BuilderFlag.FP16) @@ -265,6 +280,7 @@ def run( _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) + _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory") return TRTInterpreterResult( engine, self._input_names, self._output_names, serialized_cache diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 6572fe9588..8e54bad234 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -142,6 +142,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: if self.lower_setting.verbose_profile else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, tactic_sources=self.lower_setting.tactic_sources, + max_aux_streams=self.lower_setting.max_aux_streams, + version_compatible=self.lower_setting.version_compatible, + optimization_level=self.lower_setting.optimization_level, ) # Update timing 
cache file if needed diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index 07e7bf0dac..0bbd51ad12 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -74,6 +74,9 @@ class LowerSetting(LowerSettingBasic): correctness_atol: absolute tolerance for correctness check correctness_rtol: relative tolerance for correctness check use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++). + max_aux_streams: max number of aux streams to use + version_compatible: enable version compatible feature + optimization_level: builder optimization level """ input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -101,3 +104,6 @@ class LowerSetting(LowerSettingBasic): correctness_atol: float = 0.1 correctness_rtol: float = 0.1 use_experimental_rt: bool = False + max_aux_streams: Optional[int] = None + version_compatible: bool = False + optimization_level: Optional[int] = None diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 0b8578ffba..fedb45fdf3 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -151,7 +151,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check.
def validate_inference( - rtol=None, atol=None, run_alternative_batch_size: int = -1 + rtol=None, + atol=None, + suppress_accuracy_check_failure=True, + run_alternative_batch_size: int = -1, ) -> "Decorator": """ Returns a decorator on a PassFunc to sanity check the model outputs Args: rtol: reletive tolerance atol: absoluate tolerance + suppress_accuracy_check_failure: if True, skip the accuracy check entirely and return the transformed module without comparing its outputs against the original run_alternative_batch_size (int): In addition to running inference at original batch size in the input, also run at an alternative batch size. If set to -1, do not @@ -181,48 +185,51 @@ def pass_with_validation( *args, **kwargs, ) -> fx.GraphModule: - res0 = module(*input) - processed_module = pass_(module, input, *args, **kwargs) - res1 = processed_module(*input) - - tensor_res_0 = _collect_tensors(res0) - tensor_res_1 = _collect_tensors(res1) - relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE - - for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): - kwargs2 = {"equal_nan": True} - if rtol: - kwargs2["rtol"] = rtol - if atol: - kwargs2["atol"] = atol - kwargs2[ - "msg" - ] = ( - lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" - ) - # If tensors are on different devices, make sure to compare - # their copies that are on the same device.
- if x.get_device() != y.get_device(): - x = x.cpu() - y = y.cpu() - try: - torch.testing.assert_close(x, y, **kwargs2) - except Exception as e: - if relax_accuracy_check_failure: - _LOGGER.error(f"{e}") - kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER - kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER - new_atol = kwargs2["atol"] - new_rtol = kwargs2["rtol"] - _LOGGER.info( - f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" - ) + if suppress_accuracy_check_failure: + return pass_(module, input, *args, **kwargs) + else: + res0 = module(*input) + processed_module = pass_(module, input, *args, **kwargs) + res1 = processed_module(*input) + + tensor_res_0 = _collect_tensors(res0) + tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE + + for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): + kwargs2 = {"equal_nan": True} + if rtol: + kwargs2["rtol"] = rtol + if atol: + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) + # If tensors are on different devices, make sure to compare + # their copies that are on the same device. + if x.get_device() != y.get_device(): + x = x.cpu() + y = y.cpu() + try: torch.testing.assert_close(x, y, **kwargs2) - return processed_module - else: - raise e - - return processed_module + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" + ) + torch.testing.assert_close(x, y, **kwargs2) + return processed_module + else: + raise e + + return processed_module return pass_with_validation