5 changes: 5 additions & 0 deletions README.md
@@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -181,6 +182,10 @@ You can set this command line setting to disable the upcasting to fp32 in some c

```--dont-upcast-attention```

## How to show high-quality previews?

The default installation includes a fast but low-resolution latent preview method. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/taesd` folder. Once they are installed, restart ComfyUI to switch to high-quality previews.
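
If you would rather script the download, a minimal sketch is shown below; it assumes you run it from the ComfyUI root directory so that `models/taesd` resolves correctly.

```python
# Hedged helper: fetches the two TAESD checkpoints into models/taesd.
import os
import urllib.request

TAESD_FILES = {
    "taesd_encoder.pth": "https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth",
    "taesd_decoder.pth": "https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth",
}

os.makedirs(os.path.join("models", "taesd"), exist_ok=True)
for name, url in TAESD_FILES.items():
    dest = os.path.join("models", "taesd", name)
    if not os.path.exists(dest):
        print(f"Downloading {name}...")
        urllib.request.urlretrieve(url, dest)
```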

## Support and dev channel

[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like Discord, but open source).
38 changes: 38 additions & 0 deletions comfy/cli_args.py
@@ -1,4 +1,35 @@
import argparse
import enum


class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)

# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")

# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")

super(EnumAction, self).__init__(**kwargs)

self._enum = enum_type

def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)


parser = argparse.ArgumentParser()

@@ -13,6 +44,13 @@
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

class LatentPreviewMethod(enum.Enum):
Auto = "auto"
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.")
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
65 changes: 65 additions & 0 deletions comfy/taesd/taesd.py
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn

def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3

class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))

def Encoder():
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
)

def Decoder():
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)

class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5

def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

@staticmethod
def scale_latents(x):
"""raw latents -> [0, 1]"""
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
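
For a sense of how the decoder half is meant to be used, here is a rough sketch along the lines of the previewer added in `nodes.py`; the checkpoint paths, latent shape, and output range are assumptions rather than guarantees.

```python
import torch
from PIL import Image
from comfy.taesd.taesd import TAESD

# Assumed checkpoint locations under models/taesd; adjust to your install.
taesd = TAESD("models/taesd/taesd_encoder.pth", "models/taesd/taesd_decoder.pth").eval()

with torch.no_grad():
    latent = torch.randn(1, 4, 64, 64)        # stand-in for the latent of a 512x512 image
    rgb = taesd.decoder(latent)[0]            # (3, 512, 512), roughly in [0, 1]
    rgb = rgb.clamp(0, 1).mul(255).byte()     # scale to 0..255 for an 8-bit image
    Image.fromarray(rgb.permute(1, 2, 0).cpu().numpy()).save("preview.png")
```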
7 changes: 5 additions & 2 deletions comfy/utils.py
@@ -1,6 +1,7 @@
import torch
import math
import struct
import comfy.model_management

def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
@@ -166,6 +167,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
comfy.model_management.throw_exception_if_processing_interrupted()

s_in = s[:,:,y:y+tile_y,x:x+tile_x]

ps = function(s_in).cpu()
@@ -197,14 +200,14 @@ def __init__(self, total):
self.current = 0
self.hook = PROGRESS_BAR_HOOK

def update_absolute(self, value, total=None):
def update_absolute(self, value, total=None, preview=None):
if total is not None:
self.total = total
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
self.hook(self.current, self.total, preview)

def update(self, value):
self.update_absolute(self.current + value)
1 change: 1 addition & 0 deletions folder_paths.py
@@ -18,6 +18,7 @@
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["taesd"] = ([os.path.join(models_dir, "taesd")], supported_pt_extensions)

folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
5 changes: 4 additions & 1 deletion main.py
@@ -26,6 +26,7 @@
import execution
import folder_paths
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes


@@ -40,8 +41,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

def hijack_progress(server):
def hook(value, total):
def hook(value, total, preview_image_bytes):
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)

def cleanup_temp():
0 changes: 0 additions & 0 deletions comfy/taesd/__init__.py
Empty file.
104 changes: 101 additions & 3 deletions nodes.py
@@ -7,6 +7,8 @@
import traceback
import math
import time
import struct
from io import BytesIO

from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
@@ -22,6 +24,8 @@
import comfy.sample
import comfy.sd
import comfy.utils
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD

import comfy.clip_vision

@@ -31,13 +35,40 @@
import folder_paths


class LatentPreviewer:
def decode_latent_to_preview(self, device, x0):
pass


class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
self.latent_rgb_factors = torch.tensor([
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")

def decode_latent_to_preview(self, device, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors

latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()

return Image.fromarray(latents_ubyte.numpy())


def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()

def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)

MAX_RESOLUTION=8192
MAX_PREVIEW_RESOLUTION = 512

class CLIPTextEncode:
@classmethod
@@ -248,6 +279,21 @@ def encode(self, vae, pixels, mask, grow_mask_by=6):

return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )

class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd

def decode_latent_to_preview(self, device, x0):
x_sample = self.taesd.decoder(x0.to(device))[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)

x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

preview_image = Image.fromarray(x_sample)
return preview_image

class SaveLatent:
def __init__(self):
@@ -931,6 +977,26 @@ def set_mask(self, samples, mask):
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)


def decode_latent_to_preview_image(previewer, device, preview_format, x0):
preview_image = previewer.decode_latent_to_preview(device, x0)
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.LANCZOS)

preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2

bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format)
preview_bytes = bytesIO.getvalue()

return preview_bytes
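
The bytes returned above form a small framed payload: a 4-byte big-endian type field (1 for JPEG, 2 for PNG) followed by the encoded image, which `hijack_progress` in `main.py` forwards to the client as a `PREVIEW_IMAGE` binary event. A hedged sketch of how a receiving side could unpack such a payload (this helper is illustrative, not part of the patch):

```python
import struct
from io import BytesIO
from PIL import Image

def parse_preview_payload(data: bytes):
    """Illustrative client-side decoder for the preview payload built above."""
    (preview_type,) = struct.unpack(">I", data[:4])  # 1 = JPEG, 2 = PNG
    image = Image.open(BytesIO(data[4:]))
    return preview_type, image
```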


def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
@@ -945,9 +1011,39 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]

preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"

previewer = None
if not args.disable_previews:
# TODO previewer methods
taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth")
taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth")

method = args.default_preview_method

if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if taesd_encoder_path and taesd_decoder_path:
method = LatentPreviewMethod.TAESD

if method == LatentPreviewMethod.TAESD:
if taesd_encoder_path and taesd_decoder_path:
taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth")

if previewer is None:
previewer = Latent2RGBPreviewer()

pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
pbar.update_absolute(step + 1, total_steps)
preview_bytes = None
if previewer:
preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)

samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@@ -970,7 +1066,8 @@ def INPUT_TYPES(s):
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
}
}

RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"
@@ -997,7 +1094,8 @@ def INPUT_TYPES(s):
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
"return_with_leftover_noise": (["disable", "enable"], ),
}}
}
}

RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"