|
| 1 | +"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B. |
| 2 | +
|
| 3 | +Generates images using both conversion paths for visual comparison. |
| 4 | +Uses bf16 with CPU offload. |
| 5 | +
|
| 6 | +Usage: |
| 7 | + python benchmark_lokr.py |
| 8 | + python benchmark_lokr.py --lokr-path "puttmorbidly233/lora" --lokr-name "klein_snofs_v1_2.safetensors" |
| 9 | + python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128 |
| 10 | +""" |
| 11 | + |
| 12 | +import argparse |
| 13 | +import gc |
| 14 | +import os |
| 15 | +import time |
| 16 | + |
| 17 | +import torch |
| 18 | +from diffusers import Flux2KleinPipeline |
| 19 | +from peft import convert_to_lora |
| 20 | + |
# Hugging Face Hub id of the base model the benchmark runs against.
MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
# Default LoKR adapter checkpoint, used when --lokr-path is not supplied.
DEFAULT_LOKR_PATH = "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02"
# All output images are written here (created by main() if missing).
OUTPUT_DIR = "benchmark_output"
| 24 | + |
| 25 | + |
def load_pipeline():
    """Build the Flux2 Klein 9B pipeline in bfloat16 with model CPU offload enabled."""
    pipeline = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    # Offloading keeps only the active sub-model on GPU, trading speed for VRAM.
    pipeline.enable_model_cpu_offload()
    return pipeline
| 31 | + |
| 32 | + |
def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0):
    """Produce one 1024x1024 image from *pipe*, seeded on CPU for reproducibility."""
    rng = torch.Generator(device="cpu").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        height=1024,
        width=1024,
    )
    return result.images[0]
| 45 | + |
| 46 | + |
def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name):
    """Path A: run the LoKR adapter natively (no conversion, lossless)."""
    print("\n=== Path A: Lossless LoKR ===")
    start = time.time()
    # weight_name is only needed for multi-file repos; omit it otherwise.
    if lokr_name:
        pipe.load_lora_weights(lokr_path, weight_name=lokr_name)
    else:
        pipe.load_lora_weights(lokr_path)
    print(f" Loaded in {time.time() - start:.1f}s")

    start = time.time()
    image = generate(pipe, prompt, seed)
    print(f" Generated in {time.time() - start:.1f}s")

    # Drop the adapter so later benchmark paths start from a clean pipeline.
    pipe.unload_lora_weights()
    return image
| 61 | + |
| 62 | + |
def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name):
    """Path B: load the LoKR checkpoint, then approximate it as a rank-`rank` LoRA via SVD."""
    print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
    start = time.time()
    load_kwargs = {"weight_name": lokr_name} if lokr_name else {}
    pipe.load_lora_weights(lokr_path, **load_kwargs)
    load_time = time.time() - start

    # peft chooses the adapter name at load time; read it back instead of assuming "default".
    adapter_name = next(iter(pipe.transformer.peft_config.keys()))
    print(f" Adapter name: {adapter_name}")

    start = time.time()
    lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True)
    convert_time = time.time() - start
    print(f" Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s")

    # Swap the native LoKR adapter for its SVD-converted LoRA equivalent, reusing the name.
    from peft import inject_adapter_in_model, set_peft_model_state_dict

    pipe.transformer.delete_adapters(adapter_name)
    inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name)
    set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)

    start = time.time()
    image = generate(pipe, prompt, seed)
    print(f" Generated in {time.time() - start:.1f}s")

    # Leave the pipeline adapter-free for the next rank in the sweep.
    pipe.unload_lora_weights()
    return image
| 93 | + |
| 94 | + |
def benchmark_baseline(pipe, prompt, seed):
    """Baseline: generate with no adapter attached, for visual comparison."""
    print("\n=== Baseline: No adapter ===")
    start = time.time()
    image = generate(pipe, prompt, seed)
    elapsed = time.time() - start
    print(f" Generated in {elapsed:.1f}s")
    return image
| 102 | + |
| 103 | + |
def main():
    """CLI entry point: run baseline, lossless-LoKR, and lossy-SVD benchmark passes."""
    parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
    parser.add_argument("--prompt", default="a portrait painting in besch art style")
    parser.add_argument("--lokr-path", default=DEFAULT_LOKR_PATH, help="HF repo or local path to LoKR checkpoint")
    parser.add_argument("--lokr-name", default=None, help="Filename within HF repo (if multi-file)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
    parser.add_argument("--skip-baseline", action="store_true")
    parser.add_argument("--skip-lossy", action="store_true")
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    def _save(image, filename):
        # Write one result image into OUTPUT_DIR and echo its location.
        destination = os.path.join(OUTPUT_DIR, filename)
        image.save(destination)
        print(f" Saved: {destination}")

    print(f"Model: {MODEL_ID}")
    print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else ""))
    print(f"Prompt: {args.prompt}")
    print(f"Seed: {args.seed}")
    if not args.skip_lossy:
        print(f"SVD ranks to test: {args.ranks}")

    print("\nLoading pipeline (bf16, model CPU offload)...")
    pipe = load_pipeline()

    # Baseline (no adapter).
    if not args.skip_baseline:
        _save(benchmark_baseline(pipe, args.prompt, args.seed), "baseline.png")

    # Path A: lossless LoKR.
    _save(
        benchmark_lossless(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name),
        "lokr_lossless.png",
    )

    gc.collect()
    torch.cuda.empty_cache()

    # Path B: lossy LoRA via SVD, swept over the requested ranks.
    if not args.skip_lossy:
        for rank in args.ranks:
            _save(
                benchmark_lossy(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name),
                f"lora_svd_rank{rank}.png",
            )
            # Reclaim memory between ranks so each conversion starts fresh.
            gc.collect()
            torch.cuda.empty_cache()

    print(f"\nAll results saved to {OUTPUT_DIR}/")
    print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png")


if __name__ == "__main__":
    main()