UGVly/Orthogonal-Constrained-One-Step-Alignment

Orthogonal-Constrained One-Step Alignment

A cleaned project layout for test-time orthogonal noise optimization on SDXL Turbo, now extended with a latent-matched SFT workflow for one-step generators.

Besides the original reward-driven test-time optimization backends, this repo now includes a practical version of the idea:

Direct one-step SFT on high-quality (image, prompt) pairs is often unstable. Instead, first assign the most compatible latent noise to each target image, then fine-tune on those matched noises rather than on random noise.

That gives you a cleaner supervision signal and usually behaves much better than naive random-noise SFT.
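The contrast can be sketched on a toy linear generator (a hypothetical stand-in for the one-step model; `generator`, `W`, and the closed-form matching below are illustrative only, not the repo's actual optimizer):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy one-step "generator": a fixed linear map from noise to image space.
W = rng.normal(size=(4, 4))

def generator(z):
    return W @ z

target = rng.normal(size=4)  # stands in for a high-quality target image

# Naive SFT signal: regress the output at a *random* noise onto the target.
z_random = rng.normal(size=4)
loss_random = np.mean((generator(z_random) - target) ** 2)

# Latent-matched signal: first find the noise the current generator maps
# closest to the target (closed form here; gradient-based in the repo),
# then supervise at that matched noise instead.
z_matched = np.linalg.lstsq(W, target, rcond=None)[0]
loss_matched = np.mean((generator(z_matched) - target) ** 2)

assert loss_matched < loss_random  # matched noise leaves a far smaller residual
```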

Setup

uv sync

What is new in this version

The project now has three connected pieces for the new idea:

  1. Single-sample assigned-noise inversion
    Given one target image and one prompt, optimize a patch-wise orthogonal noise transform and save the best matched latent noise.

  2. Assigned-noise dataset builder
    Given a pairs.jsonl, build a dataset where each sample contains:

    • prompt
    • target image
    • best matched input noise
  3. Latent-matched SFT trainer
    Fine-tune the one-step SDXL Turbo model on those matched noises, with optional preserve/distillation loss so the model does not forget its original random-noise behavior too aggressively.

There is also an EM-style alternating loop:

  • E-step: reassign matched noises with the current model
  • M-step: fine-tune on the updated matched-noise dataset
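A minimal numeric sketch of this alternation, again on a hypothetical linear generator (closed-form E/M steps stand in for the repo's gradient-based optimization):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: a linear one-step generator W and a batch of target "images".
W = rng.normal(size=(4, 4))
targets = rng.normal(size=(16, 4))

for cycle in range(2):
    # E-step: assign each target the noise the *current* generator maps
    # closest to it (closed form here; optimization-based in the repo).
    Z = np.linalg.lstsq(W, targets.T, rcond=None)[0].T

    # M-step: refit the generator on the (matched noise, target) pairs.
    W = np.linalg.lstsq(Z, targets, rcond=None)[0].T

# After the final M-step, the generator reproduces each target from its
# matched noise; batch generation is Z @ W.T (one row per sample).
reconstruction_error = np.mean((Z @ W.T - targets) ** 2)
```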

Project layout

src/ttt_reward_models/
  adapters.py
  cli_sdxl_reward.py
  cli_pickscore.py
  cli_imagereward.py
  cli_hpsv2.py
  cli_noise_theory.py
  cli_assign_sdxl_sft.py
  cli_build_assigned_noise_dataset.py
  cli_train_latent_matched_sft.py
  cli_em_latent_matched_sft.py
  data.py
  diagnostics.py
  downloaders.py
  paths.py
  pipeline.py
  rewards_clip.py
  rewards_pickscore.py
  rewards_imagereward.py
  rewards_hpsv2.py
  runners.py
  runners_sft.py
  utils.py
legacy/
legacy_reference/
scripts/
examples/
third_party_weights/
models/

Install

pip install -e .

Or:

pip install -r requirements.txt

Download models into the project-local models/ folder

bash scripts/download_models.sh

That script downloads the main assets into a project-local layout like:

models/
  sdxl-turbo/
  Hyper-SD15-1step/
  PickScore_v1/
  ImageReward/
  HPSv2/
  CLIP-ViT-L-14/
  Aesthetic/
    sac+logos+ava1-l14-linearMSE.pth

Common selective downloads:

bash scripts/download_models.sh --only CLIP-ViT-L-14 --only Aesthetic
bash scripts/download_models.sh --only ImageReward
bash scripts/download_models.sh --only HPSv2 --hps-version v2.1

The code now prefers models/ by default. Older checkpoints under third_party_weights/ are still recognized as a backward-compatible fallback.

Original reward-driven one-step optimization

CLIP / Aesthetic / Hybrid

CLIP and aesthetic runs expect the local CLIP snapshot and, for aesthetic / hybrid, the MLP checkpoint too. If you used bash scripts/download_models.sh, the project-local defaults below will line up with the downloaded files.

python -m ttt_reward_models.cli_sdxl_reward \
  --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
  --reward_type clip \
  --model_id ./models/sdxl-turbo \
  --clip_local_dir ./models/CLIP-ViT-L-14 \
  --output_dir outputs/test_time_oft_noise_clip \
  --patch_size 8

python -m ttt_reward_models.cli_sdxl_reward \
  --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
  --reward_type hybrid \
  --model_id ./models/sdxl-turbo \
  --clip_local_dir ./models/CLIP-ViT-L-14 \
  --aesthetic_ckpt ./models/Aesthetic/sac+logos+ava1-l14-linearMSE.pth \
  --output_dir outputs/test_time_oft_noise_hybrid \
  --patch_size 8

PickScore

python -m ttt_reward_models.cli_pickscore \
  --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
  --model_id ./models/sdxl-turbo \
  --pickscore_model_id ./models/PickScore_v1

ImageReward

python -m ttt_reward_models.cli_imagereward \
  --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
  --model_id ./models/sdxl-turbo \
  --imagereward_model_path ./models/ImageReward/ImageReward.pt \
  --imagereward_med_config_path ./models/ImageReward/med_config.json

HPSv2

python -m ttt_reward_models.cli_hpsv2 \
  --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
  --model_id ./models/sdxl-turbo \
  --hps_version v2.1 \
  --hps_checkpoint_path ./models/HPSv2/HPS_v2.1_compressed.pt

Noise-theory visualization

Every reward CLI and assigned-noise inversion run writes diagnostics such as:

orthogonal_gaussian_init.png / .json
orthogonal_gaussian_final.png / .json

These compare the original standard Gaussian latent noise with the orthogonally transformed noise Q @ z. For a fixed orthogonal map, the Gaussian statistics are preserved up to numerical error. In the learned/data-dependent assignment setting, the geometry is still constrained to an orthogonal orbit, but the collected assigned latents should not be described as i.i.d. standard Gaussian without extra qualification.
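The fixed-map invariance is easy to check directly with NumPy (a standalone sketch, independent of the diagnostics above; here `Q` is a random orthogonal matrix obtained from a QR factorization):

```python
import numpy as np

rng = np.random.default_rng(0)

# Random orthogonal matrix via QR decomposition of a Gaussian matrix.
Q, _ = np.linalg.qr(rng.normal(size=(16, 16)))

z = rng.normal(size=(100_000, 16))  # i.i.d. standard Gaussian samples
z_rot = z @ Q.T                     # apply the fixed orthogonal map

# Norms are preserved exactly; mean and covariance match up to sampling error.
assert np.allclose(np.linalg.norm(z, axis=1), np.linalg.norm(z_rot, axis=1))
assert np.abs(z_rot.mean(axis=0)).max() < 0.05
assert np.abs(np.cov(z_rot.T) - np.eye(16)).max() < 0.05
```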

Standalone verification:

python -m ttt_reward_models.cli_noise_theory \
  --output_dir outputs/orthogonal_gaussian_theory \
  --channels 4 \
  --patch_size 2 \
  --num_samples 65536 \
  --batch_size 4096

New workflow: latent-matched SFT

Expected data format

Create a dataset folder like this:

data/high_quality_pairs/
  pairs.jsonl
  images/
    0001.png
    0002.png

Where each line in pairs.jsonl looks like:

{"prompt": "a cinematic portrait of a girl in soft light", "image": "images/0001.png"}

You can rename the fields with --prompt_field and --image_field.
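For instance, a small helper (hypothetical, not part of the repo) can emit a conforming pairs.jsonl:

```python
import json
from pathlib import Path

def write_pairs(root: Path, pairs):
    """Write (prompt, relative image path) tuples as one JSON object per line.

    Field names match the defaults; rename them if you pass
    --prompt_field / --image_field to the builder.
    """
    root.mkdir(parents=True, exist_ok=True)
    with open(root / "pairs.jsonl", "w") as f:
        for prompt, image_rel in pairs:
            f.write(json.dumps({"prompt": prompt, "image": image_rel}) + "\n")

write_pairs(Path("data/high_quality_pairs"),
            [("a cinematic portrait of a girl in soft light", "images/0001.png")])
```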

Step 1: assign the best noise for a single sample

python -m ttt_reward_models.cli_assign_sdxl_sft \
  --prompt "a cinematic portrait of a girl in soft light" \
  --target_image_path ./target.jpg \
  --model_id ./models/sdxl-turbo \
  --output_dir outputs/run_assign_noise_sft \
  --steps 40 \
  --lr 5e-4 \
  --patch_size 2 \
  --latent_loss_weight 1.0 \
  --pixel_l1_weight 0.1

This saves best_input_noise.pt, intermediate images, loss curves, and the orthogonal-Gaussian diagnostics.

Step 2: build a matched-noise dataset from many samples

python -m ttt_reward_models.cli_build_assigned_noise_dataset \
  --data_root ./data/high_quality_pairs \
  --output_root ./outputs/assigned_noise_dataset \
  --model_id ./models/sdxl-turbo \
  --steps 40 \
  --lr 5e-4 \
  --patch_size 2 \
  --latent_loss_weight 1.0 \
  --pixel_l1_weight 0.1

This writes:

outputs/assigned_noise_dataset/
  manifest.jsonl
  failures.jsonl
  summary.json
  samples/
    000000_xxx/
      best_input_noise.pt
      meta.json
      original_or_linked_target.png
      run_outputs/
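
Iterating over such a manifest is a plain JSON-Lines walk; a small reader sketch (the exact record fields depend on the builder, so treat any field names in the commented usage as placeholders):

```python
import json
from pathlib import Path

def iter_manifest(path):
    """Yield one record per non-blank line of a manifest.jsonl file."""
    with open(path) as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

# Hypothetical usage: collect per-sample noise paths for inspection.
# noises = [rec["noise_path"]
#           for rec in iter_manifest("outputs/assigned_noise_dataset/manifest.jsonl")]
```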

Step 3: fine-tune on matched noises instead of random noises

python -m ttt_reward_models.cli_train_latent_matched_sft \
  --manifest_path ./outputs/assigned_noise_dataset/manifest.jsonl \
  --model_id ./models/sdxl-turbo \
  --output_dir ./outputs/latent_matched_sft \
  --epochs 2 \
  --batch_size 1 \
  --lr 1e-5 \
  --latent_loss_weight 1.0 \
  --pixel_l1_weight 0.1 \
  --preserve_latent_weight 0.5 \
  --preserve_pixel_weight 0.05

The preserve losses are important in practice. They encourage the updated model to stay close to the base model on random-noise generations, which helps reduce catastrophic drift and keeps the model more usable off-distribution.
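One way to picture how the four weighted terms combine (weights mirror the CLI flags above; whether each term uses MSE or L1, and how the random-noise generations enter, is an assumption here rather than a transcript of runners_sft.py):

```python
import numpy as np

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def l1(a, b):
    return float(np.mean(np.abs(a - b)))

# Sketch of the combined objective: matched-noise terms pull the model
# toward the targets; preserve terms tie random-noise generations of the
# updated model (rand_*) to those of the frozen base model (base_*).
def combined_loss(out_lat, tgt_lat, out_px, tgt_px,
                  rand_lat, base_lat, rand_px, base_px,
                  latent_loss_weight=1.0, pixel_l1_weight=0.1,
                  preserve_latent_weight=0.5, preserve_pixel_weight=0.05):
    return (latent_loss_weight * mse(out_lat, tgt_lat)
            + pixel_l1_weight * l1(out_px, tgt_px)
            + preserve_latent_weight * mse(rand_lat, base_lat)
            + preserve_pixel_weight * l1(rand_px, base_px))
```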

Trainer outputs:

outputs/latent_matched_sft/
  checkpoint-best/
  checkpoint-final/
  checkpoint-epoch-001/
  train_history.png
  train_summary.json

The saved checkpoints are full diffusers pipelines, so you can reuse them as the next --model_id.

Step 4: alternating E/M cycles

python -m ttt_reward_models.cli_em_latent_matched_sft \
  --data_root ./data/high_quality_pairs \
  --work_dir ./outputs/em_latent_matched_sft \
  --initial_model_id ./models/sdxl-turbo \
  --num_cycles 2 \
  --assign_steps 40 \
  --assign_lr 5e-4 \
  --train_epochs 1 \
  --train_batch_size 1 \
  --train_lr 1e-5 \
  --latent_loss_weight 1.0 \
  --pixel_l1_weight 0.1 \
  --preserve_latent_weight 0.5 \
  --preserve_pixel_weight 0.05

This implements the basic hard-EM style loop:

  • assign better noises with the current model
  • fine-tune on those noises
  • repeat

Practical notes

  • The assigned-noise stage uses the patch-wise orthogonal noise transform, so the optimized noise stays tied to a Gaussian source in a structured way rather than becoming a totally unconstrained per-sample latent code.
  • The trainer is intentionally conservative: it only trains the UNet, while keeping the VAE and text encoders frozen.
  • The current clean implementation focuses on SDXL Turbo. The original test_time_inference.zip reference scripts are preserved under legacy_reference/ for comparison, including the SD1.5 Hyper-SD variants and the older dataset-building code.
  • The trainer here is meant as a stable baseline implementation of the idea, not the final word. Natural next steps would be LoRA-only tuning, EMA, distributed training, and explicit Gaussian-prior regularizers for more flexible per-sample latent parameterizations.
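As a rough sketch of the patch-wise orthogonal transform mentioned in the first note (the repo's exact patch layout and parameterization may differ; here each p x p latent patch is flattened across channels and rotated by one shared orthogonal matrix Q):

```python
import numpy as np

rng = np.random.default_rng(0)

C, H, W, p = 4, 8, 8, 2          # latent channels, spatial size, patch size
d = C * p * p
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))  # random orthogonal map

eps = rng.normal(size=(C, H, W))  # base Gaussian noise source

# (C, H, W) -> (num_patches, d): split H and W into p-sized patches.
patches = (eps.reshape(C, H // p, p, W // p, p)
              .transpose(1, 3, 0, 2, 4)
              .reshape(-1, d))
rotated = patches @ Q.T

# Inverse reshape back to (C, H, W).
z = (rotated.reshape(H // p, W // p, C, p, p)
            .transpose(2, 0, 3, 1, 4)
            .reshape(C, H, W))

# Orthogonality preserves each patch's norm, hence the overall norm: the
# optimized noise stays tied to the Gaussian source rather than drifting
# to an unconstrained per-sample latent code.
assert np.isclose(np.linalg.norm(z), np.linalg.norm(eps))
```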

Example helper scripts

bash scripts/run_assign_noise_sft_example.sh
bash scripts/run_build_assigned_dataset_example.sh
bash scripts/run_latent_matched_sft_example.sh
bash scripts/run_em_latent_matched_sft_example.sh

MNIST latent-matched SFT demo

The repo also includes a lightweight MNIST version of the same idea so you can debug the full workflow quickly before moving back to SD/SDXL.

What is included

  • python -m ttt_reward_models.cli_train_mnist_gan
    • trains a one-step class-conditional MNIST generator (GAN-style baseline)
  • python -m ttt_reward_models.cli_build_mnist_assigned_noise_dataset
    • for each target digit image, optimizes a matched latent noise z*
  • python -m ttt_reward_models.cli_train_mnist_latent_matched_sft
    • fine-tunes the generator on the assigned noises
  • python -m ttt_reward_models.cli_train_mnist_direct_sft
    • direct random-noise SFT baseline for comparison
  • python -m ttt_reward_models.cli_em_mnist_latent_matched_sft
    • EM-style alternation: assign -> SFT -> reassign

Recommended MNIST workflow

python -m ttt_reward_models.cli_train_mnist_gan \
  --output_dir outputs/mnist_gan \
  --epochs 10

python -m ttt_reward_models.cli_build_mnist_assigned_noise_dataset \
  --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt \
  --output_dir outputs/mnist_assigned_dataset \
  --max_items 256 \
  --assign_steps 200

python -m ttt_reward_models.cli_train_mnist_latent_matched_sft \
  --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt \
  --manifest_path outputs/mnist_assigned_dataset/manifest.jsonl \
  --output_dir outputs/mnist_latent_matched_sft \
  --epochs 5 \
  --preserve_weight 0.25

For a baseline comparison:

python -m ttt_reward_models.cli_train_mnist_direct_sft \
  --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt \
  --output_dir outputs/mnist_direct_sft \
  --epochs 5

Notes

  • The MNIST demo uses digit labels as the condition instead of text prompts.
  • Assigned-noise records are stored in a manifest.jsonl plus per-sample noise.pt files.
  • The latent-matched SFT stage includes a preserve loss against the frozen teacher generator.
  • This MNIST branch is meant as a fast algorithmic sanity-check, not as a replacement for the SDXL workflow.

Toy pipeline: MNIST and CIFAR-10 latent-matched SFT

The repo also includes two lightweight toy branches so you can test the algorithmic idea without SDXL:

  • MNIST: grayscale 28x28 one-step conditional generator
  • CIFAR-10: RGB 32x32 one-step conditional generator

Both branches support the same pattern:

  1. train a base one-step conditional generator
  2. assign a matched latent z* for each target image
  3. compare direct random-noise SFT vs latent-matched SFT
  4. optionally run an EM-style loop


CIFAR-10 quick start

python -m ttt_reward_models.cli_train_cifar_gan \
  --output_dir outputs/cifar_gan \
  --epochs 20

python -m ttt_reward_models.cli_build_cifar_assigned_noise_dataset \
  --generator_ckpt outputs/cifar_gan/cifar_gan_final.pt \
  --output_dir outputs/cifar_assigned_dataset \
  --max_items 512 \
  --assign_steps 300

python -m ttt_reward_models.cli_train_cifar_latent_matched_sft \
  --generator_ckpt outputs/cifar_gan/cifar_gan_final.pt \
  --manifest_path outputs/cifar_assigned_dataset/manifest.jsonl \
  --output_dir outputs/cifar_latent_matched_sft \
  --epochs 5 \
  --preserve_weight 0.25

CIFAR also has:

  • python -m ttt_reward_models.cli_train_cifar_direct_sft
  • python -m ttt_reward_models.cli_em_cifar_latent_matched_sft

and matching example scripts under scripts/run_cifar_*.sh.

MNIST / CIFAR toy branches

The toy branches now use orthogonal latent assignment as well:

  • sample base noise eps ~ N(0, I)
  • optimize an orthogonal matrix Q
  • use z = Q eps for matched-noise assignment

By default, the assigned-noise builders now operate on the training split. Use --test_split only when you explicitly want a held-out/demo assignment set.

Debug

If the following file is missing: .../site-packages/hpsv2/src/open_clip/bpe_simple_vocab_16e6.txt.gz, copy it over from the standalone open_clip package:

cp \
./.venv/lib/python3.13/site-packages/open_clip/bpe_simple_vocab_16e6.txt.gz \
./.venv/lib/python3.13/site-packages/hpsv2/src/open_clip/
