A cleaned project layout for test-time orthogonal noise optimization on SDXL Turbo, now extended with a latent-matched SFT workflow for one-step generators.
Besides the original reward-driven test-time optimization backends, this repo now includes a practical version of the idea: direct one-step SFT on high-quality (image, prompt) pairs is often unstable, so we first assign the most compatible latent noise to each target image, then fine-tune on those matched noises instead of random ones. This gives a cleaner supervision signal and usually behaves much better than naive random-noise SFT.
Install dependencies:

```bash
uv sync
```
The project now has three connected pieces for the new idea:

- Single-sample assigned-noise inversion: given one target image and one prompt, optimize a patch-wise orthogonal noise transform and save the best matched latent noise.
- Assigned-noise dataset builder: given a `pairs.jsonl`, build a dataset where each sample contains:
  - prompt
  - target image
  - best matched input noise
- Latent-matched SFT trainer: fine-tune the one-step SDXL Turbo model on those matched noises, with an optional preserve/distillation loss so the model does not forget its original random-noise behavior too aggressively.

There is also an EM-style alternating loop:

- E-step: reassign matched noises with the current model
- M-step: fine-tune on the updated matched-noise dataset
```
src/ttt_reward_models/
    adapters.py
    cli_sdxl_reward.py
    cli_pickscore.py
    cli_imagereward.py
    cli_hpsv2.py
    cli_noise_theory.py
    cli_assign_sdxl_sft.py
    cli_build_assigned_noise_dataset.py
    cli_train_latent_matched_sft.py
    cli_em_latent_matched_sft.py
    data.py
    diagnostics.py
    downloaders.py
    paths.py
    pipeline.py
    rewards_clip.py
    rewards_pickscore.py
    rewards_imagereward.py
    rewards_hpsv2.py
    runners.py
    runners_sft.py
    utils.py
legacy/
legacy_reference/
scripts/
examples/
third_party_weights/
models/
```
```bash
pip install -e .
```

Or:

```bash
pip install -r requirements.txt
```

Then download the model assets:

```bash
bash scripts/download_models.sh
```

That script downloads the main assets into a project-local layout like:
```
models/
    sdxl-turbo/
    Hyper-SD15-1step/
    PickScore_v1/
    ImageReward/
    HPSv2/
    CLIP-ViT-L-14/
    Aesthetic/
        sac+logos+ava1-l14-linearMSE.pth
```
Common selective downloads:

```bash
bash scripts/download_models.sh --only CLIP-ViT-L-14 --only Aesthetic
bash scripts/download_models.sh --only ImageReward
bash scripts/download_models.sh --only HPSv2 --hps-version v2.1
```

The code now prefers `models/` by default. Older checkpoints under `third_party_weights/` are still recognized as a backward-compatible fallback.
CLIP and aesthetic runs expect the local CLIP snapshot and, for aesthetic / hybrid, the MLP checkpoint too.
If you used bash scripts/download_models.sh, the project-local defaults below will line up with the downloaded files.
```bash
python -m ttt_reward_models.cli_sdxl_reward \
    --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
    --reward_type clip \
    --model_id ./models/sdxl-turbo \
    --clip_local_dir ./models/CLIP-ViT-L-14 \
    --output_dir outputs/test_time_oft_noise_clip \
    --patch_size 8
```

```bash
python -m ttt_reward_models.cli_sdxl_reward \
    --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
    --reward_type hybrid \
    --model_id ./models/sdxl-turbo \
    --clip_local_dir ./models/CLIP-ViT-L-14 \
    --aesthetic_ckpt ./models/Aesthetic/sac+logos+ava1-l14-linearMSE.pth \
    --output_dir outputs/test_time_oft_noise_clip \
    --patch_size 8
```
```bash
python -m ttt_reward_models.cli_pickscore \
    --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
    --model_id ./models/sdxl-turbo \
    --pickscore_model_id ./models/PickScore_v1
```

```bash
python -m ttt_reward_models.cli_imagereward \
    --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
    --model_id ./models/sdxl-turbo \
    --imagereward_model_path ./models/ImageReward/ImageReward.pt \
    --imagereward_med_config_path ./models/ImageReward/med_config.json
```

```bash
python -m ttt_reward_models.cli_hpsv2 \
    --prompt "a cinematic portrait of a girl in soft light, highly detailed" \
    --model_id ./models/sdxl-turbo \
    --hps_version v2.1 \
    --hps_checkpoint_path ./models/HPSv2/HPS_v2.1_compressed.pt
```

Every reward CLI and assigned-noise inversion run writes diagnostics such as:
- `orthogonal_gaussian_init.png` / `.json`
- `orthogonal_gaussian_final.png` / `.json`
These compare the original standard Gaussian latent noise with the orthogonally transformed noise Q @ z. For a fixed orthogonal map, the Gaussian statistics are preserved up to numerical error. In the learned/data-dependent assignment setting, the geometry is still constrained to an orthogonal orbit, but the collected assigned latents should not be described as i.i.d. standard Gaussian without extra qualification.
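The invariance claim for a fixed orthogonal map can be checked in a few lines of numpy, independent of the repo's `cli_noise_theory` (this is a standalone sketch, not code from the repo):

```python
# Sanity check: for a fixed orthogonal Q, the transformed noise Q @ z
# is still standard Gaussian, so its sample mean stays near 0 and its
# sample covariance stays near the identity (up to Monte Carlo error).
import numpy as np

rng = np.random.default_rng(0)
d, n = 16, 200_000

# Random orthogonal matrix via QR decomposition of a Gaussian matrix.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))

z = rng.standard_normal((n, d))   # z ~ N(0, I), one sample per row
qz = z @ Q.T                      # transformed noise Q @ z

mean_dev = np.abs(qz.mean(axis=0)).max()
cov_dev = np.abs(np.cov(qz.T) - np.eye(d)).max()
print(mean_dev, cov_dev)          # both close to 0
```

This only verifies the fixed-Q case; as noted above, learned data-dependent assignments need extra qualification.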
Standalone verification:

```bash
python -m ttt_reward_models.cli_noise_theory \
    --output_dir outputs/orthogonal_gaussian_theory \
    --channels 4 \
    --patch_size 2 \
    --num_samples 65536 \
    --batch_size 4096
```

Create a dataset folder like this:
```
data/high_quality_pairs/
    pairs.jsonl
    images/
        0001.png
        0002.png
```
Where each line in `pairs.jsonl` looks like:

```json
{"prompt": "a cinematic portrait of a girl in soft light", "image": "images/0001.png"}
```

You can rename the fields with `--prompt_field` and `--image_field`.
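If you are assembling `pairs.jsonl` programmatically, a minimal illustrative helper could look like this (`write_pairs` is hypothetical, not part of the repo's CLIs):

```python
# Hypothetical helper: write a pairs.jsonl in the format shown above,
# one JSON object per line with "prompt" and "image" fields.
import json
import tempfile
from pathlib import Path

def write_pairs(root: str, pairs: list[dict]) -> Path:
    """pairs: list of {"prompt": ..., "image": ...} records."""
    root_path = Path(root)
    root_path.mkdir(parents=True, exist_ok=True)
    out = root_path / "pairs.jsonl"
    with out.open("w", encoding="utf-8") as f:
        for rec in pairs:
            # Both fields are required (or rename via --prompt_field / --image_field).
            assert "prompt" in rec and "image" in rec
            f.write(json.dumps(rec) + "\n")
    return out

root = tempfile.mkdtemp()
path = write_pairs(root, [
    {"prompt": "a cinematic portrait of a girl in soft light",
     "image": "images/0001.png"},
])
print(path.read_text(encoding="utf-8"))
```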
```bash
python -m ttt_reward_models.cli_assign_sdxl_sft \
    --prompt "a cinematic portrait of a girl in soft light" \
    --target_image_path ./target.jpg \
    --model_id ./models/sdxl-turbo \
    --output_dir outputs/run_assign_noise_sft \
    --steps 40 \
    --lr 5e-4 \
    --patch_size 2 \
    --latent_loss_weight 1.0 \
    --pixel_l1_weight 0.1
```

This saves `best_input_noise.pt`, intermediate images, loss curves, and the orthogonal-Gaussian diagnostics.
```bash
python -m ttt_reward_models.cli_build_assigned_noise_dataset \
    --data_root ./data/high_quality_pairs \
    --output_root ./outputs/assigned_noise_dataset \
    --model_id ./models/sdxl-turbo \
    --steps 40 \
    --lr 5e-4 \
    --patch_size 2 \
    --latent_loss_weight 1.0 \
    --pixel_l1_weight 0.1
```

This writes:
```
outputs/assigned_noise_dataset/
    manifest.jsonl
    failures.jsonl
    summary.json
    samples/
        000000_xxx/
            best_input_noise.pt
            meta.json
            original_or_linked_target.png
            run_outputs/
```
```bash
python -m ttt_reward_models.cli_train_latent_matched_sft \
    --manifest_path ./outputs/assigned_noise_dataset/manifest.jsonl \
    --model_id ./models/sdxl-turbo \
    --output_dir ./outputs/latent_matched_sft \
    --epochs 2 \
    --batch_size 1 \
    --lr 1e-5 \
    --latent_loss_weight 1.0 \
    --pixel_l1_weight 0.1 \
    --preserve_latent_weight 0.5 \
    --preserve_pixel_weight 0.05
```

The preserve losses are important in practice. They encourage the updated model to stay close to the base model on random-noise generations, which helps reduce catastrophic drift and keeps the model more usable off-distribution.
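As a rough sketch of how these weighted terms could combine (the weight names mirror the CLI flags; the actual trainer internals are an assumption), in numpy:

```python
# Illustrative loss combination, NOT the repo's trainer code.
# Fit terms supervise the matched noise -> target pair; preserve terms
# keep the model close to the frozen base model on random noise.
import numpy as np

def latent_matched_sft_loss(
    pred_latent, target_latent, pred_pixel, target_pixel,
    preserve_pred_latent, preserve_base_latent,
    preserve_pred_pixel, preserve_base_pixel,
    latent_loss_weight=1.0, pixel_l1_weight=0.1,
    preserve_latent_weight=0.5, preserve_pixel_weight=0.05,
):
    # Supervised fit on the assigned (matched noise, target image) pair.
    fit = (latent_loss_weight * np.mean((pred_latent - target_latent) ** 2)
           + pixel_l1_weight * np.mean(np.abs(pred_pixel - target_pixel)))
    # Preserve/distillation terms against the frozen base model's outputs.
    keep = (preserve_latent_weight
            * np.mean((preserve_pred_latent - preserve_base_latent) ** 2)
            + preserve_pixel_weight
            * np.mean(np.abs(preserve_pred_pixel - preserve_base_pixel)))
    return fit + keep

a, b = np.ones((2, 2)), np.zeros((2, 2))
# Perfect preservation, unit fit error: 1.0 * 1 + 0.1 * 1 = 1.1
loss = latent_matched_sft_loss(a, b, a, b, a, a, a, a)
print(loss)
```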
Trainer outputs:
```
outputs/latent_matched_sft/
    checkpoint-best/
    checkpoint-final/
    checkpoint-epoch-001/
    train_history.png
    train_summary.json
```
The saved checkpoints are full diffusers pipelines, so you can reuse them as the next --model_id.
```bash
python -m ttt_reward_models.cli_em_latent_matched_sft \
    --data_root ./data/high_quality_pairs \
    --work_dir ./outputs/em_latent_matched_sft \
    --initial_model_id ./models/sdxl-turbo \
    --num_cycles 2 \
    --assign_steps 40 \
    --assign_lr 5e-4 \
    --train_epochs 1 \
    --train_batch_size 1 \
    --train_lr 1e-5 \
    --latent_loss_weight 1.0 \
    --pixel_l1_weight 0.1 \
    --preserve_latent_weight 0.5 \
    --preserve_pixel_weight 0.05
```

This implements the basic hard-EM style loop:
- assign better noises with the current model
- fine-tune on those noises
- repeat
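The loop above can be sketched on a toy linear one-step generator `g(z) = W z` (purely illustrative: here the E-step picks the best noise from a candidate pool and the M-step is a least-squares refit, whereas the real trainer optimizes orthogonal transforms and runs SGD on the UNet):

```python
# Toy hard-EM loop: assign the best noise per target, refit, repeat.
import numpy as np

rng = np.random.default_rng(0)
d, n_targets, n_candidates = 8, 32, 64

W_true = rng.standard_normal((d, d))
targets = rng.standard_normal((n_targets, d)) @ W_true.T  # "images"

W = rng.standard_normal((d, d))                    # current generator
candidates = rng.standard_normal((n_candidates, d))  # candidate noises

errs = []
for cycle in range(3):
    # E-step: hard-assign each target the candidate noise that the
    # current generator maps closest to it.
    gen = candidates @ W.T                               # (n_candidates, d)
    dists = ((targets[:, None, :] - gen[None, :, :]) ** 2).sum(-1)
    z_star = candidates[dists.argmin(axis=1)]            # matched noises
    # M-step: refit the generator on the matched (noise, target) pairs.
    X, *_ = np.linalg.lstsq(z_star, targets, rcond=None)
    W = X.T
    errs.append(float(((z_star @ W.T - targets) ** 2).mean()))
print(errs)  # non-increasing across cycles
```

Because the E-step can only lower each sample's residual and the M-step minimizes the refit exactly, the objective is monotone non-increasing across cycles in this toy setting.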
- The assigned-noise stage uses the patch-wise orthogonal noise transform, so the optimized noise stays tied to a Gaussian source in a structured way rather than becoming a totally unconstrained per-sample latent code.
- The trainer is intentionally conservative: it only trains the UNet, keeping the VAE and text encoders frozen.
- The current clean implementation focuses on SDXL Turbo. The original `test_time_inference.zip` reference scripts are preserved under `legacy_reference/` for comparison, including the SD1.5 Hyper-SD variants and the older dataset-building code.
- The trainer here is meant as a stable baseline implementation of the idea, not the final word. Natural next steps would be LoRA-only tuning, EMA, distributed training, and explicit Gaussian-prior regularizers for more flexible per-sample latent parameterizations.
```bash
bash scripts/run_assign_noise_sft_example.sh
bash scripts/run_build_assigned_dataset_example.sh
bash scripts/run_latent_matched_sft_example.sh
bash scripts/run_em_latent_matched_sft_example.sh
```

I also added a lightweight MNIST version of the same idea so you can debug the full workflow quickly before moving back to SD/SDXL.
- `python -m ttt_reward_models.cli_train_mnist_gan` - trains a one-step class-conditional MNIST generator (GAN-style baseline)
- `python -m ttt_reward_models.cli_build_mnist_assigned_noise_dataset` - for each target digit image, optimizes a matched latent noise `z*`
- `python -m ttt_reward_models.cli_train_mnist_latent_matched_sft` - fine-tunes the generator on the assigned noises
- `python -m ttt_reward_models.cli_train_mnist_direct_sft` - direct random-noise SFT baseline for comparison
- `python -m ttt_reward_models.cli_em_mnist_latent_matched_sft` - EM-style alternation: assign -> SFT -> reassign
```bash
python -m ttt_reward_models.cli_train_mnist_gan --output_dir outputs/mnist_gan --epochs 10
python -m ttt_reward_models.cli_build_mnist_assigned_noise_dataset --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt --output_dir outputs/mnist_assigned_dataset --max_items 256 --assign_steps 200
python -m ttt_reward_models.cli_train_mnist_latent_matched_sft --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt --manifest_path outputs/mnist_assigned_dataset/manifest.jsonl --output_dir outputs/mnist_latent_matched_sft --epochs 5 --preserve_weight 0.25
```

For a baseline comparison:

```bash
python -m ttt_reward_models.cli_train_mnist_direct_sft --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt --output_dir outputs/mnist_direct_sft --epochs 5
```

- The MNIST demo uses digit labels as the condition instead of text prompts.
- Assigned-noise records are stored in a `manifest.jsonl` plus per-sample `noise.pt` files.
- The latent-matched SFT stage includes a preserve loss against the frozen teacher generator.
- This MNIST branch is meant as a fast algorithmic sanity-check, not as a replacement for the SDXL workflow.
The repo also includes two lightweight toy branches so you can test the algorithmic idea without SDXL:
- MNIST: grayscale 28x28 one-step conditional generator
- CIFAR-10: RGB 32x32 one-step conditional generator
Both branches support the same pattern:
- train a base one-step conditional generator
- assign a matched latent `z*` for each target image
- compare direct random-noise SFT vs latent-matched SFT
- optionally run an EM-style loop
```bash
python -m ttt_reward_models.cli_train_mnist_gan --output_dir outputs/mnist_gan --epochs 10
python -m ttt_reward_models.cli_build_mnist_assigned_noise_dataset --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt --output_dir outputs/mnist_assigned_dataset --max_items 256 --assign_steps 200
python -m ttt_reward_models.cli_train_mnist_latent_matched_sft --generator_ckpt outputs/mnist_gan/mnist_gan_final.pt --manifest_path outputs/mnist_assigned_dataset/manifest.jsonl --output_dir outputs/mnist_latent_matched_sft --epochs 5 --preserve_weight 0.25
```

```bash
python -m ttt_reward_models.cli_train_cifar_gan --output_dir outputs/cifar_gan --epochs 20
python -m ttt_reward_models.cli_build_cifar_assigned_noise_dataset --generator_ckpt outputs/cifar_gan/cifar_gan_final.pt --output_dir outputs/cifar_assigned_dataset --max_items 512 --assign_steps 300
python -m ttt_reward_models.cli_train_cifar_latent_matched_sft --generator_ckpt outputs/cifar_gan/cifar_gan_final.pt --manifest_path outputs/cifar_assigned_dataset/manifest.jsonl --output_dir outputs/cifar_latent_matched_sft --epochs 5 --preserve_weight 0.25
```

CIFAR also has:

```bash
python -m ttt_reward_models.cli_train_cifar_direct_sft
python -m ttt_reward_models.cli_em_cifar_latent_matched_sft
```

and matching example scripts under `scripts/run_cifar_*.sh`.
The toy branches now use orthogonal latent assignment as well:

- sample base noise `eps ~ N(0, I)`
- optimize an orthogonal matrix `Q`
- use `z = Q eps` for matched-noise assignment
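One standard way to keep `Q` exactly orthogonal while optimizing it is the Cayley transform of a skew-symmetric parameter matrix. Whether the toy branches use this exact parameterization is an assumption; the sketch below just shows the idea:

```python
# Sketch (assumed parameterization, not necessarily the repo's):
# Q = (I + S)^(-1) (I - S) with S skew-symmetric is always orthogonal,
# so an unconstrained parameter matrix A can be optimized freely.
import numpy as np

def orthogonal_from_params(A: np.ndarray) -> np.ndarray:
    S = A - A.T                       # skew-symmetric: S.T == -S
    I = np.eye(A.shape[0])
    return np.linalg.solve(I + S, I - S)  # Cayley transform

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))       # unconstrained parameters
Q = orthogonal_from_params(A)

eps = rng.standard_normal(4)          # base noise eps ~ N(0, I)
z = Q @ eps                           # matched-noise candidate z = Q eps

orth_dev = np.abs(Q @ Q.T - np.eye(4)).max()
print(orth_dev)                       # ~0: Q is orthogonal
print(np.linalg.norm(z), np.linalg.norm(eps))  # norms match
```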
By default, the assigned-noise builders now operate on the training split. Use --test_split only when you explicitly want a held-out/demo assignment set.
If the following file is missing:

```
.../site-packages/hpsv2/src/open_clip/bpe_simple_vocab_16e6.txt.gz
```

copy it over from the standalone open_clip install:

```bash
cp \
    ./.venv/lib/python3.13/site-packages/open_clip/bpe_simple_vocab_16e6.txt.gz \
    ./.venv/lib/python3.13/site-packages/hpsv2/src/open_clip/
```