Skip to content

Commit 9953f22

Browse files
Add --fast argument to enable experimental optimizations.
Optimizations that might break things or lower quality will be put behind this flag first, and might be enabled by default in the future. Currently the only optimization is float8_e4m3fn matrix multiplication on Nvidia RTX 4000 (Ada) series cards or later. If you have one of these cards, you will see a speed boost when using fp8_e4m3fn Flux, for example.
1 parent d1a6bd6 commit 9953f22

File tree

4 files changed

+52
-5
lines changed

4 files changed

+52
-5
lines changed

comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class LatentPreviewMethod(enum.Enum):
123123

124124
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
125125
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
126+
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
126127

127128
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
128129
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

comfy/model_base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
9696

9797
if not unet_config.get("disable_unet_model_creation", False):
9898
if model_config.custom_operations is None:
99-
if self.manual_cast_dtype is not None:
100-
operations = comfy.ops.manual_cast
101-
else:
102-
operations = comfy.ops.disable_weight_init
99+
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
103100
else:
104101
operations = model_config.custom_operations
105102
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

comfy/model_management.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
10481048

10491049
return False
10501050

1051+
def supports_fp8_compute(device=None):
    """Return True if `device` (a CUDA device) supports native fp8 matmul.

    torch._scaled_mm with float8_e4m3fn needs CUDA compute capability 8.9
    (Ada, e.g. RTX 4000 series) or 9.0+ (Hopper and newer).

    Args:
        device: torch.device or CUDA device index to query; None means the
            current CUDA device.
    """
    # Fix: the original queried CUDA device properties unconditionally, which
    # raises on CPU-only builds or for non-CUDA devices; a capability probe
    # should just report False in those cases.
    if not torch.cuda.is_available():
        return False
    # An int index has no .type attribute and is treated as a CUDA ordinal.
    if device is not None and getattr(device, "type", "cuda") != "cuda":
        return False
    props = torch.cuda.get_device_properties(device)
    if props.major >= 9:
        return True
    if props.major < 8:
        return False
    if props.minor < 9:
        return False
    return True
1060+
10511061
def soft_empty_cache(force=False):
10521062
global cpu_state
10531063
if cpu_state == CPUState.MPS:

comfy/ops.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020
import comfy.model_management
21-
21+
from comfy.cli_args import args
2222

2323
def cast_to(weight, dtype=None, device=None, non_blocking=False):
2424
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
@@ -242,3 +242,42 @@ class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
242242

243243
class Embedding(disable_weight_init.Embedding):
244244
comfy_cast_weights = True
245+
246+
247+
def fp8_linear(self, input):
    """Try the fp8 fast path for a Linear layer via torch._scaled_mm.

    Returns the result in input's dtype, or None when the fast path does not
    apply (weight is not float8_e4m3fn, or input is not 3-dimensional), in
    which case the caller falls back to a regular cast + linear.
    """
    weight_dtype = self.weight.dtype
    if weight_dtype not in (torch.float8_e4m3fn,):
        return None
    if len(input.shape) != 3:
        return None

    batch, tokens, _ = input.shape
    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
    # _scaled_mm wants both operands in fp8; weight comes transposed.
    input_fp8 = input.to(weight_dtype)
    weight_t = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
    bias = None
    if self.bias is not None:
        # Loop-invariant: cast once instead of once per batch element.
        bias = cast_to_input(self.bias, input, non_blocking=non_blocking)

    result = torch.empty((batch, tokens, self.weight.shape[0]), device=input.device, dtype=input.dtype)
    for i in range(batch):
        if bias is not None:
            o, _ = torch._scaled_mm(input_fp8[i], weight_t, out_dtype=input.dtype, bias=bias)
        else:
            o, _ = torch._scaled_mm(input_fp8[i], weight_t, out_dtype=input.dtype)
        result[i] = o
    return result
265+
266+
class fp8_ops(manual_cast):
    """manual_cast variant whose Linear tries the fp8 matmul fast path first."""
    class Linear(manual_cast.Linear):
        def forward_comfy_cast_weights(self, input):
            # fp8_linear returns None when the fast path does not apply;
            # then fall back to the regular cast-and-linear behavior.
            result = fp8_linear(self, input)
            if result is None:
                weight, bias = cast_bias_weight(self, input)
                result = torch.nn.functional.linear(input, weight, bias)
            return result
275+
276+
277+
def pick_operations(weight_dtype, compute_dtype, load_device=None):
    """Choose the operations class for a model.

    When no runtime cast is needed (no compute dtype, or weights already in
    the compute dtype) use disable_weight_init. Otherwise prefer fp8_ops when
    --fast is set and the device supports fp8 compute, else manual_cast.
    """
    needs_cast = compute_dtype is not None and weight_dtype != compute_dtype
    if not needs_cast:
        return disable_weight_init
    if args.fast and comfy.model_management.supports_fp8_compute(load_device):
        return fp8_ops
    return manual_cast

0 commit comments

Comments
 (0)