
Conversation

@eljandoubi (Contributor) commented Mar 20, 2025:

What does this PR do?

Fixes issue from huggingface/transformers#35873

Who can review?

Review thread on the changed code:

    return tensor.float()

def _is_fp16_bf16_tensor(tensor):
    return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in (
        torch.float16,
        torch.bfloat16,
    )
@BenjaminBossan (Member) commented:
I'm a bit surprised that cache objects would pass this check successfully. Could you please provide an example where this happens? It would be very useful to have that as a unit test. Moreover, is it intended that the `is_torch_tensor(tensor)` part was removed?

@eljandoubi (Contributor, Author) commented Mar 21, 2025:

@BenjaminBossan Any Torch tensor has a `dtype` attribute, so there is no need for `is_torch_tensor`: `is_torch_tensor(tensor) or hasattr(tensor, "dtype")` is equivalent to plain `hasattr(tensor, "dtype")`.

As for an example: in transformers==4.50.0, `StaticCache`, `HybridCache`, `MambaCache`, and `OffloadedStaticCache` initialized with `dtype=torch.bfloat16` all pass the current `_is_fp16_bf16_tensor` check.
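A minimal reproduction along the lines of the requested unit test might look like this (a sketch assuming the transformers==4.50.0 `StaticCache` constructor; `GPT2Config` is just a convenient stand-in config):

```python
# Sketch of a reproduction for the claim above (assumes transformers==4.50.0,
# where StaticCache takes config, max_batch_size, max_cache_len, device, dtype).
import torch
from transformers import GPT2Config, StaticCache

config = GPT2Config()  # any decoder config works here
cache = StaticCache(
    config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.bfloat16
)

# The cache is not a torch.Tensor, yet it exposes a public dtype attribute,
# so the hasattr(tensor, "dtype") branch of _is_fp16_bf16_tensor matches it.
assert not isinstance(cache, torch.Tensor)
assert hasattr(cache, "dtype") and cache.dtype in (torch.float16, torch.bfloat16)
```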

@BenjaminBossan (Member) commented:

> Any Torch tensor has a `dtype` attribute, so there is no need for `is_torch_tensor`: `is_torch_tensor(tensor) or hasattr(tensor, "dtype")` is equivalent to plain `hasattr(tensor, "dtype")`.

Note that this was specifically added in #2387 to avoid `hasattr(tensor, "dtype")` for torch.compile. I'm not sure whether that issue still persists with newer PyTorch versions, but let's not change it without verifying first.

> As for an example: in transformers==4.50.0, `StaticCache`, `HybridCache`, `MambaCache`, and `OffloadedStaticCache` initialized with `dtype=torch.bfloat16` all pass the current `_is_fp16_bf16_tensor` check.

I see, the issue is that Cache objects seem to have a `dtype` attribute. I think it is okay to check for `callable(getattr(tensor, "float", None))`. But again, let's first check how this interacts with torch.compile. If it causes issues, we could instead do `if not isinstance(tensor, Cache) and ...`.
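For illustration, the two guards being discussed could look roughly like this (a sketch, not the merged change; in practice accelerate would need a guarded import since `Cache` lives in transformers):

```python
# Illustrative sketches of the two alternatives discussed above.
import torch
from transformers.cache_utils import Cache  # transformers' cache base class

def _is_fp16_bf16_tensor_variant_a(tensor):
    # Option A: only treat objects that can actually be cast via .float();
    # Cache objects carry a dtype but no .float() method.
    return (
        hasattr(tensor, "dtype")
        and callable(getattr(tensor, "float", None))
        and tensor.dtype in (torch.float16, torch.bfloat16)
    )

def _is_fp16_bf16_tensor_variant_b(tensor):
    # Option B: explicitly exclude Cache objects before the duck-typed check.
    return (
        not isinstance(tensor, Cache)
        and hasattr(tensor, "dtype")
        and tensor.dtype in (torch.float16, torch.bfloat16)
    )
```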

@gante commented Mar 28, 2025:

Hi folks 👋

I don't think accelerate should change; it makes reasonable assumptions about the objects: if it has a `dtype`, it can be cast.

I've opened a PR in transformers that makes that attribute private (it should never have been public in the first place): huggingface/transformers#37044
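A toy illustration of the direction that PR takes (the class below is hypothetical; see huggingface/transformers#37044 for the real change):

```python
import torch

class CacheSketch:
    """Hypothetical stand-in for a transformers Cache after the attribute is
    made private; the real classes carry much more state than this."""
    def __init__(self, dtype=torch.float32):
        self._dtype = dtype  # private: no public `dtype` to trip hasattr checks

cache = CacheSketch(torch.bfloat16)
assert not hasattr(cache, "dtype")  # accelerate's duck-typed check no longer matches
```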

@eljandoubi (Contributor, Author) commented:

@gante The PR LGTM. Still, if an object has a `dtype`, shouldn't it also have a casting method?

@github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

The github-actions bot closed this PR on Apr 30, 2025.