System Info
- transformers version: 4.57.1
- Platform: Linux-5.15.92-2.el8.navix.ncc.x86_64-x86_64-with-glibc2.39
- Python version: 3.12.3
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.5.2
- Accelerate version: 1.11.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: yes (FSDP2 + TP via Accelerate)
- Using GPU in script?: yes
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I’m using the official example script from the Accelerate repository:
examples/torch_native_parallelism/nd_parallel_trainer.py.
I only made two small changes to the example:
1. Changed the model name to GPT-OSS-20B
2. Added a quantization configuration
from transformers import AutoModelForCausalLM, Mxfp4Config

quantization_config = Mxfp4Config(dequantize=True)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    quantization_config=quantization_config,
    use_cache=False,
    **model_kwargs,
)
The Accelerate config used for the run:
distributed_type: FSDP
mixed_precision: bf16
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_version: 2
parallelism_config:
parallelism_config_cp_size: 1
parallelism_config_dp_replicate_size: 1
parallelism_config_dp_shard_size: 16
parallelism_config_tp_size: 8
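For reference, assuming the standard n-d mesh layout (total ranks = cp × dp_replicate × dp_shard × tp), the factors above imply a 128-process launch. A quick sanity check:

```python
# parallelism_config factors from the config above
cp_size = 1
dp_replicate_size = 1
dp_shard_size = 16
tp_size = 8

# Total ranks the n-d mesh requires
world_size = cp_size * dp_replicate_size * dp_shard_size * tp_size
print(world_size)  # 128
```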
Expected behavior
Training should proceed normally. Instead, when training the GPT-OSS-20B model with tensor parallelism (TP) and FSDP2, a runtime error occurs inside the tensor-parallel integration code of transformers: the traceback shows an IndexError: tuple index out of range raised from transformers/integrations/tensor_parallel.py during the forward pass (specifically in _prepare_input_fn).
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2328, in train
[rank2]: return inner_training_loop(
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2672, in _inner_training_loop
[rank2]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/trl/trainer/sft_trainer.py", line 1161, in training_step
[rank2]: return super().training_step(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 4009, in training_step
[rank2]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/trl/trainer/sft_trainer.py", line 1079, in compute_loss
[rank2]: (loss, outputs) = super().compute_loss(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 4099, in compute_loss
[rank2]: outputs = model(**inputs)
[rank2]: ^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank2]: return inner()
[rank2]: ^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
[rank2]: result = forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 940, in wrapper
[rank2]: output = func(self, *args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 663, in forward
[rank2]: outputs: MoeModelOutputWithPast = self.model(
[rank2]: ^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 1064, in wrapper
[rank2]: outputs = func(self, *args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 502, in forward
[rank2]: hidden_states = decoder_layer(
[rank2]: ^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_layers.py", line 94, in __call__
[rank2]: return super().__call__(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank2]: return inner()
[rank2]: ^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
[rank2]: result = forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank2]: return func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 366, in forward
[rank2]: hidden_states, _ = self.self_attn(
[rank2]: ^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank2]: return self.checkpoint_fn( # type: ignore[misc]
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 53, in inner
[rank2]: return disable_fn(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank2]: return fn(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/checkpoint.py", line 495, in checkpoint
[rank2]: ret = function(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank2]: return inner()
[rank2]: ^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1816, in inner
[rank2]: args_result = hook(self, args)
[rank2]: ^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/integrations/tensor_parallel.py", line 381, in <lambda>
[rank2]: module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformers/integrations/tensor_parallel.py", line 510, in _prepare_input_fn
[rank2]: input_tensor = inputs[0]
[rank2]: ~~~~~~^^^
[rank2]: IndexError: tuple index out of range
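The failure mode above can be reproduced in isolation. PyTorch forward pre-hooks registered without `with_kwargs=True` receive only positional arguments, so when a wrapped module is invoked with keyword arguments only (as the decoder layer invokes `self_attn` here), the hook sees an empty tuple and `inputs[0]` raises. A minimal sketch (`ToyAttention` and `prepare_input_fn` are hypothetical stand-ins mirroring the indexing in `_prepare_input_fn`, not the actual transformers code):

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for the TP-wrapped self_attn module."""

    def forward(self, hidden_states=None):
        return hidden_states


module = ToyAttention()


# Mirrors the indexing in _prepare_input_fn: pre-hooks registered without
# with_kwargs=True only ever see *positional* arguments.
def prepare_input_fn(mod, inputs):
    input_tensor = inputs[0]  # IndexError when the call was kwargs-only
    return (input_tensor,)


module.register_forward_pre_hook(prepare_input_fn)

# Positional call: inputs == (tensor,) inside the hook, works fine.
positional_ok = torch.equal(module(torch.ones(2)), torch.ones(2))

# Kwargs-only call: inputs == () inside the hook, reproduces the crash.
kwarg_call_failed = False
try:
    module(hidden_states=torch.ones(2))
except IndexError:
    kwarg_call_failed = True

print(positional_ok, kwarg_call_failed)  # True True
```

This suggests the hook either needs `with_kwargs=True` or a guard for an empty positional tuple before indexing.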