
Symbolic trace failed when using HFTracer with Qwen3Model. AttributeError: 'Tensor' object has no attribute 'get'. #43142

@tongl2

Description


System Info

  • transformers version: 4.57.3
  • Platform: Linux-6.8.0-71-generic-x86_64-with-glibc2.39
  • Python version: 3.11.14
  • Huggingface_hub version: 0.36.0
  • Safetensors version: 0.7.0
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu129 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H20

Who can help?

@Cyrilvallez I'm not sure who maintains the utils.fx module; please correct me if I tagged the wrong maintainer.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Run the following script:
from transformers import AutoConfig, AutoModel
from transformers.utils.fx import HFTracer, symbolic_trace
from transformers.modeling_utils import no_init_weights
import torch
from torch.fx import GraphModule

cfg = AutoConfig.from_pretrained('qwen3_8b_config.json')
with no_init_weights():
    model = AutoModel.from_config(cfg).eval()

dummy_input = {
    "input_ids": torch.randint(0, 1000, (1, 64)),
    "attention_mask": torch.ones(1, 64, dtype=torch.long),
    'position_ids': None,
    'past_key_values': None,
    'inputs_embeds': None,
    'use_cache': None,
    'cache_position': None,
}

# For debug:
import torch
import traceback
def _tensor_get(self, key, default=None):
    print("\n⚠️  A Tensor was mistakenly called with .get()!")
    print(f'{type(self)=}')
    print(f'{self=}')
    print("Tensor shape:", self.shape, "dtype:", self.dtype)
    print("Call stack:")
    traceback.print_stack()
    raise AttributeError(f"'Tensor' object has no attribute 'get'. This tensor was used at the above location.")
torch.Tensor.get = _tensor_get

tracer = HFTracer()
graph = tracer.trace(model, concrete_args=dummy_input)
graph_module = GraphModule(model, graph)

graph_module.graph.print_tabular()
  2. Output:
/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py:924: UserWarning: Was not able to add assertion to guarantee correct input input_ids to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.
  warnings.warn(
/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py:924: UserWarning: Was not able to add assertion to guarantee correct input attention_mask to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.
  warnings.warn(

⚠️  A Tensor was mistakenly called with .get()!
type(self)=<class 'torch.Tensor'>
self=tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
Tensor shape: torch.Size([1, 64]) dtype: torch.int64
Call stack:
  File "/***/workspace/torch_primer/code/fx/fx_qwen3_next_new.py", line 35, in <module>
    graph = tracer.trace(model, concrete_args=dummy_input)
  File "/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/transformers/utils/fx.py", line 1316, in trace
    self.graph = super().trace(root, concrete_args=concrete_args)
  File "/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 851, in trace
    (self.create_arg(fn(*args)),),
  File "/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/transformers/utils/generic.py", line 970, in wrapper
    )
  File "/***/workspace/torch_primer/code/fx/fx_qwen3_next_new.py", line 30, in _tensor_get
    traceback.print_stack()
Traceback (most recent call last):
  File "/***/workspace/torch_primer/code/fx/fx_qwen3_next_new.py", line 35, in <module>
    graph = tracer.trace(model, concrete_args=dummy_input)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/transformers/utils/fx.py", line 1316, in trace
    self.graph = super().trace(root, concrete_args=concrete_args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 851, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/***/opt/miniconda3/envs/torch-primer-311/lib/python3.11/site-packages/transformers/utils/generic.py", line 970, in wrapper
    )
^^^^^
  File "/***/workspace/torch_primer/code/fx/fx_qwen3_next_new.py", line 31, in _tensor_get
    raise AttributeError(f"'Tensor' object has no attribute 'get'. This tensor was used at the above location.")
AttributeError: 'Tensor' object has no attribute 'get'. This tensor was used at the above location.
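The two UserWarnings at the top of the output come from torch.fx itself: because input_ids and attention_mask are passed as real tensors in concrete_args, torch.fx specializes the trace on those values and cannot add a runtime check for them, so they are baked in rather than traced as graph inputs. A more conventional setup keeps the real inputs out of concrete_args and pins only the optional parameters to their defaults. The sketch below builds such a dict from a function signature. build_concrete_args and toy_forward are hypothetical names; toy_forward only borrows the parameter names from dummy_input in the script above and is not the real Qwen3Model.forward. As far as I know this roughly mirrors what transformers.utils.fx.symbolic_trace does with its input_names argument, but I have not verified that it avoids the .get error on 4.57.3.

```python
import inspect
from typing import Any, Callable, Dict, Iterable


def build_concrete_args(forward: Callable[..., Any], input_names: Iterable[str]) -> Dict[str, Any]:
    """Hypothetical helper: pin every parameter of `forward` except `input_names`
    to its declared default, so only `input_names` would stay symbolic."""
    keep_symbolic = set(input_names)
    concrete: Dict[str, Any] = {}
    for param in inspect.signature(forward).parameters.values():
        # *args / **kwargs cannot be supplied as named concrete args; skip them.
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        if param.name in keep_symbolic:
            continue
        # A parameter without a default has no value to specialize on.
        if param.default is inspect.Parameter.empty:
            raise ValueError(f"{param.name!r} has no default; add it to input_names")
        concrete[param.name] = param.default
    return concrete


# Stand-in signature using the parameter names from dummy_input (assumption).
def toy_forward(input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, use_cache=None,
                cache_position=None, **kwargs):
    return None


args = build_concrete_args(toy_forward, ["input_ids", "attention_mask"])
print(sorted(args))
# → ['cache_position', 'inputs_embeds', 'past_key_values', 'position_ids', 'use_cache']
```

With the real model, the untested usage would be tracer.trace(model, concrete_args=build_concrete_args(model.forward, ["input_ids", "attention_mask"])), leaving input_ids and attention_mask as placeholders that, as I understand it, HFTracer fills with its own dummy inputs during tracing.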

Expected behavior

Tracing should succeed, and graph_module.graph.print_tabular() should print the traced graph.
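The _tensor_get hook in the reproduction script stays installed on torch.Tensor for the rest of the process. A scoped variant, sketched below on the assumption that a context manager is acceptable for this kind of one-off debugging, records the caller's stack instead of printing it and removes the patch on exit. trap_method and Box are hypothetical names used only for illustration.

```python
import contextlib
import traceback
from typing import Any, Iterator, List


@contextlib.contextmanager
def trap_method(cls: type, name: str, log: List[str]) -> Iterator[None]:
    """Temporarily install `cls.<name>` so that any call appends the caller's
    stack to `log` and raises AttributeError; the class is restored on exit."""
    sentinel = object()
    original = cls.__dict__.get(name, sentinel)

    def trap(self: Any, *args: Any, **kwargs: Any) -> Any:
        # Drop the last frame (this function) so the log ends at the caller.
        log.append("".join(traceback.format_stack(limit=10)[:-1]))
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

    setattr(cls, name, trap)
    try:
        yield
    finally:
        if original is sentinel:
            delattr(cls, name)  # attribute did not exist before: remove it
        else:
            setattr(cls, name, original)  # put the original attribute back


class Box:
    pass


calls: List[str] = []
with trap_method(Box, "get", calls):
    try:
        Box().get("key")
    except AttributeError as err:
        print(err)  # 'Box' object has no attribute 'get'
print(hasattr(Box, "get"))  # False: the patch was removed on exit
```

In the reproduction this would wrap the trace call, e.g. with trap_method(torch.Tensor, "get", calls): tracer.trace(model, concrete_args=dummy_input). The AttributeError still propagates out of the block, but calls[0] then holds the stack leading into the .get() call, and torch.Tensor is left unpatched afterwards.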
