Description
❓ Question
When trying to compile the FLUX.1-dev model with Torch-TensorRT, following the official example/blog post, I'm encountering a ValueError during the torch_tensorrt.dynamo.compile() step. The error indicates that input parsing encountered a boolean value it doesn't know how to handle.
What you have already tried
I'm following the exact steps from the example provided in the documentation (https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_export_flux_dev.html). I've:
- Successfully loaded the FLUX.1-dev model
- Defined the dynamic shapes properly
- Created dummy inputs with the recommended dimensions (roughly as in the sketch after this list)
- Successfully exported the model using _export
- Attempted to compile with Torch-TensorRT using the same parameters shown in the example
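For context, my dummy inputs and export call are essentially the ones from the tutorial. The sketch below is reconstructed from memory rather than copied verbatim from my notebook, so the exact Dim bounds and dtypes are approximate; DEVICE is a placeholder for my A100:

import torch
from diffusers import FluxPipeline
from torch.export._trace import _export  # private export API used in the tutorial

DEVICE = "cuda:0"

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16
)
backbone = pipe.transformer.to(DEVICE)

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)  # bounds approximate

dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}

dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(DEVICE),
    "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=torch.float16).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(DEVICE),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,  # plain Python bool that the compile step later complains about
}

ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,  # flag as shown in the tutorial
)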
The error occurs specifically at the compilation step:
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
)
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- PyTorch Version (e.g., 1.0): 2.6.0
- CPU Architecture:
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.11.10
- CUDA version: cuda_12.4.r12.4/compiler.34097967_0
- GPU models and configuration: A100
- Any other relevant information:
Additional context
The error message specifically points to an issue with boolean input types:
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
It looks like the return_dict=False parameter in my dummy inputs is causing the issue, since it's a boolean value. The example shows that this should be supported, but the error suggests that booleans aren't handled in the input parsing logic.
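A stripped-down way to see the same rejection, assuming prepare_inputs can be imported from torch_tensorrt.dynamo.utils as the traceback path suggests (a hypothetical standalone repro, not taken from the tutorial):

import torch
from torch_tensorrt.dynamo.utils import prepare_inputs

# Stand-in for the FLUX dummy inputs: one tensor plus the boolean flag.
inputs = {
    "hidden_states": torch.randn(2, 4096, 64, dtype=torch.float16),
    "return_dict": False,  # plain Python bool
}

# Should raise the same error once the parser reaches the bool value:
# ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. ...
prepare_inputs(inputs)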
Full traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/workspace/flux-dev-tensorrt.ipynb Cell 4 line 1
----> 1 trt_gm = torch_tensorrt.dynamo.compile(
      2     ep,
      3     inputs=dummy_inputs,
      4     enabled_precisions={torch.float32},
      5     truncate_double=True,
      6     min_block_size=1,
      7     use_fp32_acc=True,
      8     use_explicit_typing=True,
      9 )
File /usr/local/lib/python3.11/dist-packages/torch_tensorrt/dynamo/_compiler.py:606, in compile(exported_program, inputs, arg_inputs, kwarg_inputs, device, disable_tf32, assume_dynamic_shape_support, sparse_weights, enabled_precisions, engine_capability, debug, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, truncate_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules, pass_through_build_failures, max_aux_streams, version_compatible, optimization_level, use_python_runtime, use_fast_partitioner, enable_experimental_decompositions, dryrun, hardware_compatible, timing_cache_path, lazy_engine_init, cache_built_engines, reuse_cached_engines, engine_cache_dir, engine_cache_size, custom_engine_cache, use_explicit_typing, use_fp32_acc, refit_identical_engine_weights, strip_engine_weights, immutable_weights, enable_weight_streaming, **kwargs)
603 arg_inputs = [arg_inputs] # type: ignore
605 # Prepare torch_trt inputs
--> 606 trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
607 trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
608 device = to_torch_tensorrt_device(device)
File /usr/local/lib/python3.11/dist-packages/torch_tensorrt/dynamo/utils.py:257, in prepare_inputs(inputs, disable_memory_format_check)
255 torchtrt_input_list = []
256 for input_obj in inputs:
--> 257 torchtrt_input = prepare_inputs(
258 input_obj, disable_memory_format_check=disable_memory_format_check
259 )
260 torchtrt_input_list.append(torchtrt_input)
262 return (
263 torchtr