
Conversation

@winglian
Contributor

What does this PR do?

When loading a pretrained model in 8-bit or 4-bit, the keep_in_fp32_modules kwarg is passed down as None, but the downstream function that receives it expects a list, because it uses the argument in an .extend() call.

/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:564: in from_pretrained
    return model_class.from_pretrained(
/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/modeling_utils.py:272: in _wrapper
    return func(*args, **kwargs)
/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/modeling_utils.py:4389: in from_pretrained
    hf_quantizer.preprocess_model(
/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/quantizers/base.py:215: in preprocess_model
    return self._process_model_before_weight_loading(model, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer object at 0x2ac940b28890>
model = LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-...05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
device_map = {'': 0}, keep_in_fp32_modules = None, kwargs = {}
get_keys_to_not_convert = <function get_keys_to_not_convert at 0x2ac92e464680>
replace_with_bnb_linear = <function replace_with_bnb_linear at 0x2ac92e4667a0>
llm_int8_enable_fp32_cpu_offload = False

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
    
        llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
    
        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
        if self.quantization_config.llm_int8_skip_modules is None:
            self.modules_to_not_convert = get_keys_to_not_convert(model)
        else:
            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
    
        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]
    
>       self.modules_to_not_convert.extend(keep_in_fp32_modules)
E       TypeError: 'NoneType' object is not iterable

/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/quantizers/quantizer_bnb_8bit.py:264: TypeError
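
For context, a minimal reproduction sketch (not taken from the PR itself): any 8-bit load through BitsAndBytesConfig reaches this path, because from_pretrained forwards keep_in_fp32_modules=None to the bnb quantizer. The checkpoint name below is an assumption, chosen only because the traceback matches a small Llama-architecture model.

    # Hedged reproduction sketch; the checkpoint is hypothetical, any causal LM
    # loaded in 8-bit on an affected transformers version should hit the same path.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM-135M",  # assumption: any small causal LM checkpoint works
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map={"": 0},
    )
    # On affected versions this fails with: TypeError: 'NoneType' object is not iterable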

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @SunMarc @muellerzr

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot marked this pull request as draft March 12, 2025 15:19
@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@winglian winglian marked this pull request as ready for review March 12, 2025 15:19
@winglian
Contributor Author

@Cyrilvallez looks like this is a side effect of #36033, which was merged today

@SunMarc
Member

SunMarc commented Mar 12, 2025

Yeah, this is a known issue! It will be fixed in this PR: #36672 (comment). Thanks for the report and the fix!

@Cyrilvallez
Member

Indeed, this part should be the responsibility of the quantizer, not the main loading logic, since the argument has a default value of None!
It is currently only an issue for bnb; thanks @SunMarc for taking care of it!
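
For illustration, a minimal sketch of the quantizer-side handling suggested above, assuming the bnb quantizer takes ownership of the None default (this is not the merged patch):

    from typing import List, Optional

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: Optional[List[str]] = None,  # accept None instead of requiring a list
        **kwargs,
    ):
        ...
        # Treat a missing/None value as "nothing extra to keep in fp32".
        self.modules_to_not_convert.extend(keep_in_fp32_modules or [])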

@ArthurZucker
Collaborator

PR is merged, closing as fixed and handled!

