IndexError: tuple index out of range when using Tensor Parallelism with FSDP2 on GPT-OSS 20B (tensor_parallel.py, line 510) #41819

@JdRion

System Info

  • transformers version: 4.57.1
  • Platform: Linux-5.15.92-2.el8.navix.ncc.x86_64-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.5.2
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@3outeille @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I’m using the official example script from the Accelerate repository:
examples/torch_native_parallelism/nd_parallel_trainer.py.

I only made two small changes to the example:
1. Changed the model name to GPT-OSS-20B
2. Added a quantization configuration

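from transformers import AutoModelForCausalLM, Mxfp4Config

# args and model_kwargs come from the example script
quantization_config = Mxfp4Config(dequantize=True)
model = AutoModelForCausalLM.from_pretrained(args.model_name, quantization_config=quantization_config, use_cache=False, **model_kwargs)

distributed_type: FSDP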
mixed_precision: bf16

fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2

parallelism_config:
  parallelism_config_cp_size: 1
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 16
  parallelism_config_tp_size: 8
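
For reference, here is a quick sanity check of how many processes this layout implies. It assumes the device mesh size is the product of the four parallelism sizes; that is an assumption about how these settings combine, not something stated in the config itself, and the helper name `required_world_size` is made up for illustration.

```python
import math

# Sizes copied from the parallelism_config section above.
parallelism = {
    "cp_size": 1,
    "dp_replicate_size": 1,
    "dp_shard_size": 16,
    "tp_size": 8,
}


def required_world_size(sizes):
    """Hypothetical helper: multiply all parallelism dimensions together."""
    return math.prod(sizes.values())


world_size = required_world_size(parallelism)
print(world_size)  # 1 * 1 * 16 * 8 = 128
```

Under that assumption the run needs 128 processes, i.e. 16 nodes if each node has 8 GPUs.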

Expected behavior

Training the GPT-OSS 20B model with Tensor Parallelism (TP) and FSDP2 should run without errors. Instead, a runtime error is raised inside the tensor parallel integration code of transformers.
The traceback shows IndexError: tuple index out of range, raised in _prepare_input_fn (transformers/integrations/tensor_parallel.py, line 510) during the forward pass.

[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2328, in train
[rank2]:     return inner_training_loop(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2672, in _inner_training_loop
[rank2]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank2]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/trl/trainer/sft_trainer.py", line 1161, in training_step
[rank2]:     return super().training_step(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 4009, in training_step
[rank2]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/trl/trainer/sft_trainer.py", line 1079, in compute_loss
[rank2]:     (loss, outputs) = super().compute_loss(
[rank2]:                       ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 4099, in compute_loss
[rank2]:     outputs = model(**inputs)
[rank2]:               ^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank2]:     return inner()
[rank2]:            ^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 940, in wrapper
[rank2]:     output = func(self, *args, **kwargs)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 663, in forward
[rank2]:     outputs: MoeModelOutputWithPast = self.model(
[rank2]:                                       ^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 1064, in wrapper
[rank2]:     outputs = func(self, *args, **kwargs)
[rank2]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 502, in forward
[rank2]:     hidden_states = decoder_layer(
[rank2]:                     ^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_layers.py", line 94, in __call__
[rank2]:     return super().__call__(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank2]:     return inner()
[rank2]:            ^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 366, in forward
[rank2]:     hidden_states, _ = self.self_attn(
[rank2]:                        ^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank2]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 53, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/checkpoint.py", line 495, in checkpoint
[rank2]:     ret = function(*args, **kwargs)
[rank2]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank2]:     return inner()
[rank2]:            ^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1816, in inner
[rank2]:     args_result = hook(self, args)
[rank2]:                   ^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/integrations/tensor_parallel.py", line 381, in <lambda>
[rank2]:     module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
[rank2]:                                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformers/integrations/tensor_parallel.py", line 510, in _prepare_input_fn
[rank2]:     input_tensor = inputs[0]
[rank2]:                    ~~~~~~^^^
[rank2]: IndexError: tuple index out of range
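
A possible explanation, not confirmed against the transformers source: the last frames show a forward pre-hook doing `inputs[0]`. In plain PyTorch, a pre-hook registered without `with_kwargs=True` only receives positional arguments, so if the hooked module is called with keyword arguments only, `inputs` is an empty tuple. The sketch below reproduces that behavior in isolation. `TinyAttn` and `prepare_input` are hypothetical stand-ins, not the real GPT-OSS attention module or `_prepare_input_fn`.

```python
import torch
import torch.nn as nn


class TinyAttn(nn.Module):
    """Hypothetical stand-in for a module that gets a TP input hook."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


def prepare_input(mod, inputs):
    # Mirrors the failing line from the traceback: assumes at least one
    # positional argument was passed to the module.
    input_tensor = inputs[0]
    return (input_tensor,)


module = TinyAttn()
module.register_forward_pre_hook(prepare_input)
x = torch.randn(2, 4)

# Positional call: the hook sees inputs == (x,), so inputs[0] works.
out = module(x)

# Keyword-only call: the hook sees inputs == (), so inputs[0] raises.
try:
    module(hidden_states=x)
    kwargs_call_failed = False
    error_message = ""
except IndexError as exc:
    kwargs_call_failed = True
    error_message = str(exc)

print(kwargs_call_failed, error_message)
```

If this is what happens here, it would fit the traceback, where `self.self_attn(` is invoked through the activation checkpoint wrapper right before the hook fires. Whether the arguments actually arrive as keywords in that path would need to be checked in the model code.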
