🐛 [Bug] Transformers T5 Model does not compile via FX Path #1740

Closed
@gs-olive

Bug Description

When compiling the T5-Base model via the FX path, the following errors are encountered. The model can optionally be pre-traced using the HuggingFace symbolic tracer; the errors for both cases (Pre-Traced / NOT Pre-Traced) are shown below.

NOT Pre-Traced torch_tensorrt.fx.compile(model, is_aten=True,...)
[2023-03-16 00:51:26,256] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-03-16 00:51:40,315] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-03-16 00:51:40,640] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function dynamo_normalization_capturing_compiler
[2023-03-16 00:51:40,640] torch._dynamo.output_graph: [INFO] Step 2: done compiler function dynamo_normalization_capturing_compiler
[2023-03-16 00:51:41,904] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __getitem__
[2023-03-16 00:51:41,911] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing __getitem__ (RETURN_VALUE)
Traceback (most recent call last):
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 116, in dynamo_trace
    return torchdynamo.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 706, in export
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 118, in __call__
    return self.dynamo_ctx(self._orig_mod.__call__)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 254, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/t5/modeling_t5.py", line 1395, in forward
    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 391, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 105, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 263, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _compile
    hooks.guard_export_fn(output.guards)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 670, in guard_export_print
    assert out_guards is None, "whole graph export entails exactly one guard export"
AssertionError: whole graph export entails exactly one guard export
Pre-Traced torch_tensorrt.fx.compile(model, is_aten=True,...)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 159, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 447, in compose_bmm
    new_func,
UnboundLocalError: local variable 'new_func' referenced before assignment
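
For readers unfamiliar with this error class: an UnboundLocalError like the one above typically appears when a local variable is assigned only inside some conditional branches and then used unconditionally. The sketch below illustrates that general pattern only. It is not the code of `compose_bmm`; the function names and op strings are hypothetical.

```python
def pick_replacement(op_name):
    """Hypothetical stand-in for a pass that selects a replacement op."""
    if op_name == "matmul":
        new_func = "bmm"
    elif op_name == "expand":
        new_func = "noop"
    # No else branch: for any other op_name, new_func is never bound.
    return new_func


def pick_replacement_safe(op_name):
    """Same selection logic, with new_func bound on every path."""
    new_func = None  # explicit fallback for unmatched ops
    if op_name == "matmul":
        new_func = "bmm"
    elif op_name == "expand":
        new_func = "noop"
    return new_func


# An op that matches neither branch triggers the error in the first version.
try:
    pick_replacement("softmax")
except UnboundLocalError as exc:
    error_message = str(exc)
```

Whether a fallback value or an explicit "skip this node" path is the right fix depends on the pass; the second function only shows one way to make every path bind the variable.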
Pre-Traced torch_tensorrt.fx.compile(model, is_aten=True,...) + PR #1708
Got 5 acc subgraphs and 6 non-acc subgraphs
Traceback (most recent call last):
  File "case_dict.py", line 217, in <module>
    T5MODEL()
  File "case_dict.py", line 135, in T5MODEL
    fx_trt_model = torch_tensorrt.fx.compile(traced, [input_ids, attention_mask, decoder_input_ids],
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 202, in lower_func
    lowered_module = self._lower_func(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 178, in lower_pass
    interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 130, in __call__
    interp_result: TRTInterpreterResult = interpreter.run(
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 204, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 137, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 275, in run_node
    trt_node = super().run_node(n)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 179, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 328, in call_function
    return converter(self.network, target, args, kwargs, self._cur_node_name)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py", line 57, in aten_ops_adaptive_avg_poolnd
    raise RuntimeError(f"We do not support {target} has dim={args[1]}")
RuntimeError: We do not support aten.mean.dim has dim=[-1]

While executing %mean_dim : [#users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_tensor_scalar, [-1], True), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f0c8007ccf0>: None, <tensorrt.tensorrt.ITensor object at 0x7f0c803853b0>: ((1, 1, 1, 14), torch.float32, False, (14, 14, 14, 1), torch.channels_last, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f0c80354630>: None, <tensorrt.tensorrt.ITensor object at 0x7f0c80385a30>: ((1, 14, 768), torch.float32, True, (10752, 768, 1), torch.contiguous_format, False, {})}})
Original traceback:
  File "<eval_with_key>.0", line 24, in forward
    mean = pow_1.mean(-1, keepdim = True);  pow_1 = None
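
The failing node is a mean over the last dimension with `keepdim=True`, as the original traceback line `pow_1.mean(-1, keepdim = True)` shows. As a minimal sketch of what that reduction means shape-wise, the helper functions below (hypothetical names, plain Python, no TensorRT involved) map a negative dim to its non-negative index and compute the resulting shape. The input shape `(1, 14, 768)` is taken from the tensor metadata in the log; that it belongs to the input of the mean is an assumption.

```python
def normalize_dims(dims, rank):
    """Map possibly negative reduction dims to non-negative indices."""
    normalized = []
    for d in dims:
        if not -rank <= d < rank:
            raise IndexError(f"dim {d} out of range for rank {rank}")
        normalized.append(d % rank)
    return normalized


def reduced_shape(shape, dims, keepdim):
    """Shape produced by reducing `shape` over `dims`."""
    reduce_over = set(normalize_dims(dims, len(shape)))
    if keepdim:
        # Reduced axes are kept with size 1.
        return tuple(1 if i in reduce_over else s for i, s in enumerate(shape))
    # Reduced axes are dropped.
    return tuple(s for i, s in enumerate(shape) if i not in reduce_over)


shape = (1, 14, 768)  # assumed input shape, from the log's tensor metadata
print(normalize_dims([-1], len(shape)))          # [2]
print(reduced_shape(shape, [-1], keepdim=True))  # (1, 14, 1)
```

This only describes the semantics of the call; it does not reflect how the converter in aten_ops_converters.py decides which dims it accepts.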

To Reproduce

Steps to reproduce the behavior:

  1. Initialize model: T5Model.from_pretrained("t5-base").eval().cuda()
  2. Initialize three input tensors ("input_ids", "attention_mask", "decoder_input_ids"), each for example as: torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
  3. (Optional) Use the transformers tools to trace the model via: transformers.utils.fx.symbolic_trace(model, input_names=["input_ids", "attention_mask", "decoder_input_ids"])
  4. Compile the model via the FX path: torch_tensorrt.fx.compile(model, is_aten=True,...)
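
The steps above can be sketched as a single script. This is an approximation under stated assumptions, not the exact `case_dict.py` from the traceback: the constant and function names are illustrative, and the compile call passes only `is_aten=True` because the report elides the remaining arguments. Running `reproduce()` requires a CUDA GPU plus the torch, transformers, and torch_tensorrt versions listed under Environment.

```python
INPUT_NAMES = ["input_ids", "attention_mask", "decoder_input_ids"]
INPUT_SHAPE = (1, 14)


def reproduce(pre_trace=True):
    """Run the reproduction steps; needs a CUDA GPU and the packages below."""
    # Heavy imports are deferred so defining this function needs no GPU stack.
    import torch
    import torch_tensorrt
    from transformers import T5Model
    from transformers.utils.fx import symbolic_trace

    # Step 1: initialize the model.
    model = T5Model.from_pretrained("t5-base").eval().cuda()

    # Step 2: three integer inputs of the same shape, as in the report.
    inputs = [
        torch.randint(0, 1, INPUT_SHAPE, dtype=torch.int32).to("cuda")
        for _ in INPUT_NAMES
    ]

    # Step 3 (optional): pre-trace with the HuggingFace symbolic tracer.
    if pre_trace:
        model = symbolic_trace(model, input_names=INPUT_NAMES)

    # Step 4: compile via the FX path. The report elides further arguments
    # ("is_aten=True,..."), so only is_aten is passed here.
    return torch_tensorrt.fx.compile(model, inputs, is_aten=True)
```

Calling `reproduce(pre_trace=False)` corresponds to the "NOT Pre-Traced" case and `reproduce(pre_trace=True)` to the "Pre-Traced" case.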

Expected behavior

The model should compile successfully via the FX path.

Environment

  • Transformers: 4.26.1
  • Torch-TensorRT Version (e.g. 1.0.0): fce0a01
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230313+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7
