The issue is caused by the fact that Float8Linear captures the input dtype in its forward pass (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear.py#L303), and an assert later in sync_float8_amax_and_scale_history (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear_utils.py#L247) then causes the failure. In the model below, the activations reaching the Float8Linear layers have mixed dtypes under autocast: the original input is float32, the outputs of l1 and l2 are bfloat16, and the LayerNorm output is float32 again, since LayerNorm outputs float32 regardless of the autocast settings.
One trivial solution would be to use https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd with cast_inputs=torch.get_autocast_gpu_dtype().
The following script reproduces the issue (run without args) and the trivial solution (add --wrap_linear_layer); a short standalone snippet after the script isolates the dtype behavior.
import argparse

import torch
from torch import get_autocast_gpu_dtype
from torch.cuda.amp import custom_fwd

from float8_experimental.float8_linear import Float8Linear as BaseFloat8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)


def get_args():
    p = argparse.ArgumentParser()
    p.add_argument('--wrap_linear_layer', dest="wrap_linear_layer", action="store_true")
    return p.parse_args()


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(16, 16)
        self.l2 = torch.nn.Linear(16, 32)
        # norm layer is just an example - it can be any layer that outputs float32 regardless of autocast settings
        self.norm_layer = torch.nn.LayerNorm(32)
        self.l3 = torch.nn.Linear(32, 16)

    def forward(self, x):
        # x is still float32
        x = self.l1(x)
        # x is now bfloat16
        x = self.l2(x)
        # x is still bfloat16
        x = self.norm_layer(x)
        # x is now float32 (since the output of the norm layer is float32 regardless of autocast settings)
        x = self.l3(x)
        # x is now bfloat16
        return x


if __name__ == '__main__':
    args = get_args()
    m = SimpleModel().to('cuda')
    if args.wrap_linear_layer:
        class Float8Linear(BaseFloat8Linear):
            @custom_fwd(cast_inputs=get_autocast_gpu_dtype())
            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
    else:
        Float8Linear = BaseFloat8Linear
    swap_linear_with_float8_linear(m, Float8Linear)
    b = torch.rand([17, 16]).to('cuda')
    with torch.amp.autocast(enabled=True, device_type='cuda', dtype=torch.bfloat16):
        out = m(b)
    sync_float8_amax_and_scale_history(m)
    print('Done !')
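Running the script without args fails at the assert inside sync_float8_amax_and_scale_history; with --wrap_linear_layer, the custom_fwd wrapper gives every Float8Linear a uniform input dtype and the script completes, printing Done !. The dtype behavior the repro relies on can also be seen in isolation. The snippet below is a minimal standalone sketch, not part of the original report; the layer sizes are arbitrary.

import torch

# Under CUDA autocast, Linear runs in the autocast dtype while LayerNorm
# produces float32 regardless of the autocast settings.
lin = torch.nn.Linear(8, 8).to('cuda')
norm = torch.nn.LayerNorm(8).to('cuda')
x = torch.rand([4, 8]).to('cuda')

with torch.amp.autocast(enabled=True, device_type='cuda', dtype=torch.bfloat16):
    print(lin(x).dtype)   # torch.bfloat16
    print(norm(x).dtype)  # torch.float32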
Thanks for the report. We should be able to enable this by refactoring the sync_float8_amax_and_scale_history function and removing the assert you linked.