
Float8Linear does not support autocast #257

Closed
yitzhaklevi opened this issue May 6, 2024 · 2 comments

@yitzhaklevi commented May 6, 2024

The issue is caused by Float8Linear capturing the input dtype during forward (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear.py#L303). Under autocast, different Float8Linear layers can see different input dtypes, and the assert during sync_float8_amax_and_scale_history (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear_utils.py#L247) then fails.
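
For background, a minimal sketch (independent of float8_experimental) of why input dtypes diverge under autocast: nn.Linear produces outputs in the autocast dtype, while LayerNorm autocasts to float32, so consecutive linear layers receive inputs of different dtypes.

import torch

lin = torch.nn.Linear(16, 16).cuda()
norm = torch.nn.LayerNorm(16).cuda()
x = torch.rand(4, 16, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    y = lin(x)   # Linear runs in the autocast dtype
    z = norm(y)  # LayerNorm autocasts its output to float32

print(y.dtype)  # torch.bfloat16
print(z.dtype)  # torch.float32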

One trivial solution would be to use https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd (with cast_inputs=torch.get_autocast_gpu_dtype())

The following script reproduces the issue (run without args) and demonstrates the trivial solution (run with --wrap_linear_layer):

import argparse

import torch
from torch import get_autocast_gpu_dtype
from torch.cuda.amp import custom_fwd

from float8_experimental.float8_linear import Float8Linear as BaseFloat8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)


def get_args():
    p = argparse.ArgumentParser()
    p.add_argument('--wrap_linear_layer', dest="wrap_linear_layer", action="store_true")
    return p.parse_args()


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(16, 16)
        self.l2 = torch.nn.Linear(16, 32)
        # LayerNorm is just an example; any layer that outputs float32 regardless of autocast settings would do
        self.norm_layer = torch.nn.LayerNorm(32) 
        self.l3 = torch.nn.Linear(32, 16)

    def forward(self, x):
        # x is still float32
        x = self.l1(x)
        # x is now bfloat16
        x = self.l2(x)
        # x is still bfloat16
        x = self.norm_layer(x)
        # x is now float32 (since the output of norm layer is float32 regardless of autocast settings)
        x = self.l3(x)
        # x is now bfloat16
        return x


if __name__ == '__main__':
    args = get_args()
    m = SimpleModel().to('cuda')
    if args.wrap_linear_layer:
        class Float8Linear(BaseFloat8Linear):
            @custom_fwd(cast_inputs=get_autocast_gpu_dtype())
            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
    else:
        Float8Linear = BaseFloat8Linear

    swap_linear_with_float8_linear(m, Float8Linear)
    b = torch.rand([17, 16]).to('cuda')
    with torch.amp.autocast(enabled=True, device_type='cuda', dtype=torch.bfloat16):
        out = m(b)
        sync_float8_amax_and_scale_history(m)
    print('Done !')
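
A note on the workaround: per the torch.cuda.amp.custom_fwd docs, cast_inputs only casts floating-point CUDA tensors when forward runs inside an autocast-enabled region (and forward then executes with autocast disabled); outside autocast it is a no-op, so the wrapped Float8Linear behaves exactly like the unwrapped one.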

@vkuzo (Contributor) commented May 28, 2024

Thanks for the report. We should be able to enable this by refactoring the sync_float8_amax_and_scale_history function and removing the assert you linked.

@vkuzo (Contributor) commented Jul 30, 2024

pytorch/ao#568

vkuzo closed this as completed Jul 30, 2024