check that an object has .float() method #3451
Conversation
```python
        return tensor.float()

    def _is_fp16_bf16_tensor(tensor):
        return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in (
            torch.float16,
            torch.bfloat16,
        )
```
I'm a bit surprised that cache objects would pass this check successfully. Could you please provide an example where this happens? It would be very useful to have that as a unit test. Moreover, is it intended that the `is_torch_tensor(tensor)` part was removed?
@BenjaminBossan Any Torch tensor has a `dtype` attribute, so there is no need for `is_torch_tensor`: `is_torch_tensor(tensor) or hasattr(tensor, "dtype")` is equivalent to plain `hasattr(tensor, "dtype")`.
For the example: in `transformers==4.50.0`, the `StaticCache`, `HybridCache`, `MambaCache`, and `OffloadedStaticCache` classes initialized with `dtype=torch.bfloat16` all pass the current `_is_fp16_bf16_tensor` check.
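A minimal unit-test sketch of that claim, using a hypothetical `_FakeCache` stand-in (the real cache classes are more involved; the only property that matters for the check is the `dtype` attribute):

```python
import torch


class _FakeCache:
    # Hypothetical stand-in for StaticCache/HybridCache/etc.: like them, it
    # exposes a `dtype` attribute without being a torch.Tensor.
    def __init__(self, dtype):
        self.dtype = dtype


def test_cache_like_object_passes_dtype_only_check():
    cache = _FakeCache(torch.bfloat16)
    # The dtype-only condition is what lets cache objects slip through:
    assert hasattr(cache, "dtype") and cache.dtype in (torch.float16, torch.bfloat16)
    # ...even though there is no .float() method to call on them afterwards:
    assert not callable(getattr(cache, "float", None))
```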
> Any Torch tensor has a `dtype` attribute, so there is no need for `is_torch_tensor`: `is_torch_tensor(tensor) or hasattr(tensor, "dtype")` is equivalent to plain `hasattr(tensor, "dtype")`.
Note that this was specifically added in #2387 to avoid `hasattr(tensor, "dtype")` for torch.compile. I'm not sure if this issue still persists with newer PyTorch versions, but let's not change it without verifying first.
> For the example: in `transformers==4.50.0`, the `StaticCache`, `HybridCache`, `MambaCache`, and `OffloadedStaticCache` classes initialized with `dtype=torch.bfloat16` all pass the current `_is_fp16_bf16_tensor` check.
I see, the issue is that `Cache` objects seem to have a `dtype` attribute. I think it is okay to check for `callable(getattr(tensor, "float", None))`. But again, let's first check how this interacts with torch.compile. If it causes issues, we could instead do `if not isinstance(tensor, Cache) and ...`.
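A sketch of what that check could look like inside `_is_fp16_bf16_tensor`, assuming the helper keeps its current shape; whether `hasattr`/`getattr` play well with torch.compile is exactly the open question above, and this is not the merged implementation:

```python
import torch


def _is_fp16_bf16_tensor(tensor):
    # Require a callable .float() on top of a half-precision dtype, so that
    # cache objects that merely carry a `dtype` attribute are skipped.
    return (
        hasattr(tensor, "dtype")
        and callable(getattr(tensor, "float", None))
        and tensor.dtype in (torch.float16, torch.bfloat16)
    )
```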
Hi folks 👋 I don't think I've opened a PR in …
@gante The PR LGTM. Still, if an object has `dtype`, it needs to have a caster method, doesn't it?
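A quick counterexample: a NumPy array carries a `dtype` but casts via `.astype()`, not via a `.float()` method:

```python
import numpy as np

a = np.zeros(3, dtype=np.float16)
print(hasattr(a, "dtype"))                  # True
print(callable(getattr(a, "float", None)))  # False: NumPy casts via .astype()
print(a.astype(np.float32).dtype)           # float32
```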
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Fixes the issue reported in huggingface/transformers#35873.
Who can review?