Skip to content

Commit b5958e6

Browse files
Add comprehensive Flux2 LoKR adapter support with dual conversion paths
- BFL format: remap keys + split fused QKV via Kronecker re-factorization (Van Loan)
- LyCORIS format: decode underscore-encoded paths to diffusers module names
- Diffusers native format: add transformer. prefix and bake alpha
- Generic lossy path: _convert_adapter_to_lora utility wrapping peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
1 parent 8cb0b7b commit b5958e6

File tree

3 files changed

+391
-32
lines changed

3 files changed

+391
-32
lines changed

benchmark_lokr.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B.
2+
3+
Generates images using both conversion paths for visual comparison.
4+
Uses bf16 with CPU offload.
5+
6+
Usage:
7+
python benchmark_lokr.py
8+
python benchmark_lokr.py --lokr-path "puttmorbidly233/lora" --lokr-name "klein_snofs_v1_2.safetensors"
9+
python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128
10+
"""
11+
12+
import argparse
13+
import gc
14+
import os
15+
import time
16+
17+
import torch
18+
from diffusers import Flux2KleinPipeline
19+
from peft import convert_to_lora
20+
21+
MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
22+
DEFAULT_LOKR_PATH = "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02"
23+
OUTPUT_DIR = "benchmark_output"
24+
25+
26+
def load_pipeline():
    """Instantiate the Flux2 Klein 9B pipeline in bfloat16 with CPU offload.

    Model CPU offload streams submodules between host and GPU on demand,
    letting the 9B model run on limited VRAM at a modest speed cost.
    """
    pipeline = Flux2KleinPipeline.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16
    )
    pipeline.enable_model_cpu_offload()
    return pipeline
31+
32+
33+
def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0):
    """Run one 1024x1024 generation, seeding a CPU generator for reproducibility.

    Returns the first (only) image from the pipeline output.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        height=1024,
        width=1024,
    )
    return result.images[0]
45+
46+
47+
def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name):
    """Path A: load the LoKR checkpoint natively (no conversion, lossless)."""
    print("\n=== Path A: Lossless LoKR ===")
    start = time.time()
    # weight_name is only needed when the repo holds more than one file.
    extra = {"weight_name": lokr_name} if lokr_name else {}
    pipe.load_lora_weights(lokr_path, **extra)
    print(f"  Loaded in {time.time() - start:.1f}s")

    start = time.time()
    image = generate(pipe, prompt, seed)
    print(f"  Generated in {time.time() - start:.1f}s")

    # Leave the pipeline adapter-free for the next benchmark run.
    pipe.unload_lora_weights()
    return image
61+
62+
63+
def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name):
    """Path B: load the LoKR, then approximate it as a rank-`rank` LoRA via SVD."""
    print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
    start = time.time()
    extra = {"weight_name": lokr_name} if lokr_name else {}
    pipe.load_lora_weights(lokr_path, **extra)
    load_secs = time.time() - start

    # peft picks the adapter name at load time; read back whatever it chose.
    adapter_name = next(iter(pipe.transformer.peft_config.keys()))
    print(f"  Adapter name: {adapter_name}")

    start = time.time()
    lora_config, lora_sd = convert_to_lora(
        pipe.transformer, rank, adapter_name=adapter_name, progressbar=True
    )
    convert_secs = time.time() - start
    print(f"  Loaded LoKR in {load_secs:.1f}s, converted to LoRA in {convert_secs:.1f}s")

    # Swap the exact LoKR adapter out for its SVD-compressed LoRA approximation
    # under the same adapter name, so downstream unload/generate code is unchanged.
    from peft import inject_adapter_in_model, set_peft_model_state_dict

    pipe.transformer.delete_adapters(adapter_name)
    inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name)
    set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)

    start = time.time()
    image = generate(pipe, prompt, seed)
    print(f"  Generated in {time.time() - start:.1f}s")

    pipe.unload_lora_weights()
    return image
93+
94+
95+
def benchmark_baseline(pipe, prompt, seed):
    """Reference run with no adapter loaded at all."""
    print("\n=== Baseline: No adapter ===")
    start = time.time()
    img = generate(pipe, prompt, seed)
    print(f"  Generated in {time.time() - start:.1f}s")
    return img
102+
103+
104+
def main():
    """Drive the benchmark: baseline, lossless LoKR, then lossy LoRA at each rank."""
    args = _parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Model: {MODEL_ID}")
    suffix = f" ({args.lokr_name})" if args.lokr_name else ""
    print(f"LoKR: {args.lokr_path}" + suffix)
    print(f"Prompt: {args.prompt}")
    print(f"Seed: {args.seed}")
    if not args.skip_lossy:
        print(f"SVD ranks to test: {args.ranks}")

    print("\nLoading pipeline (bf16, model CPU offload)...")
    pipe = load_pipeline()

    # Baseline (optional).
    if not args.skip_baseline:
        _save_image(benchmark_baseline(pipe, args.prompt, args.seed), "baseline.png")

    # Path A: lossless LoKR.
    _save_image(
        benchmark_lossless(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name),
        "lokr_lossless.png",
    )
    _free_memory()

    # Path B: lossy LoRA via SVD, one run per requested rank.
    if not args.skip_lossy:
        for rank in args.ranks:
            img = benchmark_lossy(
                pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name
            )
            _save_image(img, f"lora_svd_rank{rank}.png")
            _free_memory()

    print(f"\nAll results saved to {OUTPUT_DIR}/")
    print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png")


def _parse_args():
    """Build and evaluate the benchmark's command-line interface."""
    parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
    parser.add_argument("--prompt", default="a portrait painting in besch art style")
    parser.add_argument(
        "--lokr-path", default=DEFAULT_LOKR_PATH,
        help="HF repo or local path to LoKR checkpoint",
    )
    parser.add_argument(
        "--lokr-name", default=None,
        help="Filename within HF repo (if multi-file)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
    parser.add_argument("--skip-baseline", action="store_true")
    parser.add_argument("--skip-lossy", action="store_true")
    return parser.parse_args()


def _save_image(img, filename):
    """Write *img* into OUTPUT_DIR and echo the destination path."""
    path = os.path.join(OUTPUT_DIR, filename)
    img.save(path)
    print(f"  Saved: {path}")


def _free_memory():
    """Drop Python garbage and release cached CUDA memory between runs."""
    gc.collect()
    torch.cuda.empty_cache()
156+
157+
158+
# Script entry point: run the full benchmark when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)