
Commit 83125e9: macOS fixes

2 parents: a535ab7 + 1499393

4 files changed: +28 -13 lines

modules/module/quantized/LinearGGUFA8.py (+3 -3)
@@ -1,8 +1,8 @@
+from modules.util.mm_8bit import mm_8bit as mm_8bit
 from modules.util.quantization_util import (
     quantize_fp8_axiswise,
     quantize_int8_axiswise,
 )
-from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit

 import torch
 from torch import Tensor
@@ -39,14 +39,14 @@ def fp8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor | None, compute
 def int8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor:
     output_8, output_scale = quantize_int8_axiswise(output, dim=-1)
     w_8, w_scale = quantize_int8_axiswise(weight, dim=0)
-    mm_res = triton_mm_8bit(output_8.contiguous(), w_8)
+    mm_res = mm_8bit(output_8.contiguous(), w_8)
     return mm_res.float().mul_(w_scale).mul_(output_scale).to(output.dtype)

 @torch.no_grad()
 def fp8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor:
     output_8, output_scale = quantize_fp8_axiswise(output, dim=-1)
     w_8, w_scale = quantize_fp8_axiswise(weight, dim=0)
-    mm_res = triton_mm_8bit(output_8.contiguous(), w_8)
+    mm_res = mm_8bit(output_8.contiguous(), w_8)
     return mm_res.float().mul_(w_scale).mul_(output_scale).to(output.dtype)

 class LinearGGUFIntA8RequantFunction(torch.autograd.Function):
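For context, the scale bookkeeping in int8_backward_axiswise can be emulated in plain PyTorch. The sketch below is an illustration only: it does not call the repo's quantize_int8_axiswise or the 8-bit matmul kernel, and the per-row/per-column scale convention is an assumption inferred from how w_scale and output_scale are applied to the matmul result.

import torch

def quantize_rows_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # symmetric per-row quantization: x ≈ q * scale, scale has shape (m, 1)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    q = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def quantize_cols_int8(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # symmetric per-column quantization: w ≈ q * scale, scale has shape (1, n)
    scale = w.abs().amax(dim=0, keepdim=True).clamp(min=1e-12) / 127.0
    q = (w / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

grad_out = torch.randn(64, 128)
weight = torch.randn(128, 256)

g_8, g_scale = quantize_rows_int8(grad_out)
w_8, w_scale = quantize_cols_int8(weight)

# emulate the 8-bit matmul in int32, then undo both scales,
# mirroring mm_res.float().mul_(w_scale).mul_(output_scale)
approx = (g_8.to(torch.int32) @ w_8.to(torch.int32)).float() * w_scale * g_scale
print((approx - grad_out @ weight).abs().max())  # small quantization error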

modules/module/quantized/LinearW8A8.py (+7 -7)
@@ -1,14 +1,14 @@

 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin
 from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin
+from modules.util.mm_8bit import mm_8bit as mm_8bit
 from modules.util.quantization_util import (
     dequantize,
     quantize_fp8_axiswise,
     quantize_fp8_tensorwise,
     quantize_int8_axiswise,
     quantize_int8_tensorwise,
 )
-from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit

 import torch
 from torch import Tensor, nn
@@ -37,13 +37,13 @@ def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias:
 def int8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor:
     output_8, output_scale = quantize_int8_axiswise(output, dim=-1)
     # almost always, grad outputs are already contiguous and this is a no-op. But there are some grad outputs from SDXL that are non-contiguous:
-    mm_res = triton_mm_8bit(output_8.contiguous(), weight)
+    mm_res = mm_8bit(output_8.contiguous(), weight)
     return mm_res.float().mul_(weight_scale * output_scale).to(output.dtype)

 @torch.no_grad()
 def fp8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor:
     output_8, output_scale = quantize_fp8_axiswise(output, dim=-1)
-    mm_res = triton_mm_8bit(output_8.contiguous(), weight)
+    mm_res = mm_8bit(output_8.contiguous(), weight)
     return mm_res.float().mul_(weight_scale * output_scale).to(output.dtype)


@@ -158,11 +158,11 @@ def benchmark_int8(m, k, n, device = 'cuda'):


     run_benchmark(lambda: torch._int_mm(x_8, w_8.T), "torch mm int")
-    run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm int")
+    run_benchmark(lambda: mm_8bit(x_8, w_8.T), "triton mm int")
     def torch_backward(a, b):
         torch._int_mm(a, b.T.contiguous().T)
     run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8")
-    run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8")
+    run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward int8")

     run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", compile=True)
     run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True)
@@ -179,11 +179,11 @@ def benchmark_fp8(m, k, n, device = 'cuda'):
     one_scale = torch.ones(1, device=device)

     run_benchmark(lambda: torch._scaled_mm(x_8, w_8.T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()), "torch mm fp8")
-    run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm fp8")
+    run_benchmark(lambda: mm_8bit(x_8, w_8.T), "triton mm fp8")
     def torch_backward(a, b):
         torch._scaled_mm(a, b.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float())
     run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8")
-    run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8")
+    run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward fp8")
     run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8", compile=True)
     run_benchmark(lambda: fp8_backward_axiswise(y, w_8, w_scale), "triton backward fp8", compile=True)
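The in-diff comment about non-contiguous SDXL grad outputs can be seen with a toy example; this is only an illustration (made-up shapes, not taken from the repo) of why the .contiguous() call is cheap in the common case and only copies when the layout actually is strided:

import torch

x = torch.randn(8, 16)
y = x.contiguous()                     # already contiguous: returns the same tensor, a no-op
print(y.data_ptr() == x.data_ptr())    # True

t = x.t()                              # transposed view: non-contiguous, like some SDXL grad outputs
print(t.is_contiguous())               # False
c = t.contiguous()                     # here a real copy happens so the 8-bit matmul gets a dense layout
print(c.data_ptr() == t.data_ptr())    # False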

modules/util/dtype_util.py (+3 -3)
@@ -33,9 +33,9 @@ def create_autocast_context(
 ) -> tuple[torch.autocast | nullcontext, DataType]:
     if torch.backends.mps.is_available():
         if any(train_dtype != dt for dt in weight_dtypes if dt is not None):
-            raise RuntimeError("macOS needs all dtypes to be the same.")
-
-        return nullcontext(), train_dtype
+            print("Warning: Mixed precision training is untested on macOS. Consider setting all dtypes to be the same.")
+        else:
+            return nullcontext(), train_dtype

     weight_dtypes = list(weight_dtypes)
     weight_dtypes = list(filter(lambda dtype: dtype != DataType.NONE and dtype is not None, weight_dtypes))
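The condition that now only warns (instead of raising) on MPS boils down to comparing the train dtype against every non-None weight dtype. A standalone sketch of that check, with torch dtypes standing in for the project's DataType enum (an assumption for illustration only):

import torch

train_dtype = torch.bfloat16
weight_dtypes = [torch.bfloat16, None, torch.float8_e4m3fn]

# same expression as in create_autocast_context: Nones are skipped,
# any remaining mismatch means mixed precision is requested
mixed = any(train_dtype != dt for dt in weight_dtypes if dt is not None)

if mixed:
    # previously this raised on macOS; after this commit it only warns
    # and falls through to building an autocast context
    print("Warning: Mixed precision training is untested on macOS. "
          "Consider setting all dtypes to be the same.")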

modules/util/mm_8bit.py (new file, +15 -0)
@@ -0,0 +1,15 @@
+try:
+    from modules.util.triton_mm_8bit import mm_8bit
+except ImportError as e:
+    print(str(e) + ", continuing without triton")
+    import torch
+    def mm_8bit(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+        assert a.is_contiguous(), "Matrix A must be contiguous"
+        assert a.dtype == b.dtype, "Incompatible dtypes"
+        assert a.dtype in [torch.int8, torch.float8_e4m3fn]
+        if a.dtype == torch.int8:
+            return torch._int_mm(a, b)
+        else:
+            one = torch.ones(1, device=a.device)
+            return torch._scaled_mm(a, b.T.contiguous().T, scale_a=one, scale_b=one)
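A minimal way to exercise the int8 fallback path outside the repo looks roughly like the sketch below; mm_8bit_fallback is a hypothetical stand-in for the new module's except-branch, and torch._int_mm support on your device and PyTorch version is an assumption:

import torch

def mm_8bit_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # same shape/dtype checks as the new module, int8 branch only
    assert a.shape[1] == b.shape[0] and a.is_contiguous()
    assert a.dtype == b.dtype == torch.int8
    # CPU support for torch._int_mm depends on the PyTorch build
    return torch._int_mm(a, b)

a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
out = mm_8bit_fallback(a, b)
print(out.dtype, out.shape)  # torch.int32 torch.Size([32, 32])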
