Float8 autoquant weight only #866
@@ -69,7 +69,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         input = (
             input.contiguous()
         )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except:
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """

Review comment (on the added try/except): maybe adding a comment to this would be helpful, how these two branches are handled?

Reply: The except is executed if it's a float8 dtype on H100, as there's no implementation for addmm_cuda for float8 dtypes. Added as a comment.
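For reference, here is a minimal self-contained sketch of the behaviour the reply describes. It assumes `out_dtype` is imported from `torch._higher_order_ops.out_dtype`, as in the surrounding file; the name `safe_int_mm_sketch`, the tensor shapes, and the float8 usage lines are illustrative, not part of the PR.

import torch
from torch._higher_order_ops.out_dtype import out_dtype  # assumed import path

def safe_int_mm_sketch(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    try:
        # Fast path: integer matmul with int32 accumulation via the out_dtype op.
        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
    except Exception:
        # Per the reply above, float8 inputs on H100 land here because there is
        # no addmm_cuda implementation for float8 dtypes; emulate in float32.
        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)

# Illustrative usage (requires a CUDA device with float8 support):
if torch.cuda.is_available():
    x = torch.randn(64, 64, device="cuda").to(torch.float8_e4m3fn)
    w = torch.randn(64, 64, device="cuda").to(torch.float8_e4m3fn)
    out = safe_int_mm_sketch(x, w)  # float8 inputs take the fallback branch

Note that the sketch catches `Exception` rather than using the bare `except:` shown in the diff, which would also swallow interrupts; the diff itself is reproduced unchanged above.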