diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index a329bbe28a..a7337f4f8e 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -98,6 +98,13 @@ def lift( ) assert fake_mode is not None + # This map stores the names of outputs (old to new) + # This is necessary to track because the output names can be changed when + # we convert graph constants to placeholder inputs below. + output_names = {} + for output_spec in graph_signature.output_specs: + output_names[output_spec.arg.name] = output_spec.arg.name + # Locate the user input to insert new placeholders before them first_user_input = None for node in gm.graph.nodes: @@ -139,9 +146,8 @@ def lift( # Replace get_attr nodes with placeholder nodes and copy metadata. with gm.graph.inserting_before(first_user_input): # Ensure name doesn't contain period as it is used for submodules - const_placeholder_node = gm.graph.placeholder( - node.target.replace(".", "_") - ) + const_placeholder_name = node.target.replace(".", "_") + const_placeholder_node = gm.graph.placeholder(const_placeholder_name) # Copy the node meta into this new placeholder node const_placeholder_node.meta = node.meta @@ -157,6 +163,12 @@ def lift( node.replace_all_uses_with(const_placeholder_node) gm.graph.erase_node(node) + # Verify if the const_placeholder being added is one of the output nodes + # This happens if there is just a single static arange op in the graph + # https://github.com/pytorch/TensorRT/issues/3189 + if const_placeholder_name in output_names: + output_names[const_placeholder_name] = const_placeholder_node.name + # Add these parameters/buffers/constants to the existing graph signature # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation. input_spec_arg = TensorArgument(name=const_placeholder_node.name) @@ -174,6 +186,11 @@ def lift( ) non_user_input_idx += 1 + # Update output_specs with modified names. This only gets updated if the graph getattr nodes (weights) + # are also the outputs of the graph + for output_spec in graph_signature.output_specs: + output_spec.arg.name = output_names[output_spec.arg.name] + gm.graph.eliminate_dead_code() gm.graph.lint() diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 146cc2addf..470da496ba 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -381,6 +381,64 @@ def forward(self, x): ) +@pytest.mark.unit +def test_arange_export(ir): + """ + This tests export save and load functionality on a arange static graph + Here the arange output is a static constant (which is registered as input to the graph) + in the exporter. + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_embed = torch.arange( + 1, x.shape[-1] + 1, dtype=torch.float32, device=x.device + ) + return x_embed + + model = MyModule().eval().cuda() + input = torch.randn((1, 1, 128, 128)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) + + torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + + deser_trt_module = torchtrt.load(trt_ep_path).module() + outputs_pyt = model(input) + outputs_trt = trt_module(input) + + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_arange_export TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + outputs_trt_deser = deser_trt_module(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_arange_export deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + @pytest.mark.unit def test_save_load_ts(ir): """