
Non-record: Compressor-Aware Training (CAT), differentiable compression proxies for LZ-family compressors#1385

Open
korentomas wants to merge 2 commits into openai:main from korentomas:cat-submission

Conversation

@korentomas

Summary

  • Novel technique: differentiable proxies for LZ-family compression (dictionary matching + entropy coding) as a training regularizer
  • Standard training is indifferent to how compressible the resulting weights are; this submission makes it compression-aware
  • Dictionary matching proxy approximates LZ77 via multi-lag soft autocorrelation on the serialized quantized weight byte stream
  • Entropy proxy approximates Huffman/FSE via soft-histogram Shannon entropy (both proxies are sketched after this list)
  • 1xH100 ablation: 6.8% artifact size reduction at +0.009 BPB cost (combined), up to 20% at higher lambda
  • No prior work trains neural network weights for LZ77-style dictionary matching
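
A minimal sketch of the two proxies (names, lags, and defaults here are illustrative; the actual implementations in train_golf.py subsample `cat_sample_size` positions before computing either loss):

```python
import torch

def lz77_proxy_loss(x: torch.Tensor, lags=(1, 2, 4, 8, 16, 32),
                    temperature: float = 50.0) -> torch.Tensor:
    # Dictionary-matching proxy: bytes that repeat at common LZ77 offsets
    # give small squared differences, so each soft match scores near 1.
    match_score = x.new_zeros(())
    for lag in lags:
        diff_sq = (x[lag:] - x[:-lag]).square()
        match_score = match_score + torch.exp(-diff_sq / temperature).mean()
    return -match_score / len(lags)  # minimizing encourages repetition

def entropy_proxy_loss(x: torch.Tensor, bandwidth: float = 1.0,
                       num_bins: int = 64) -> torch.Tensor:
    # Entropy proxy: soft histogram via Gaussian kernels, then Shannon
    # entropy, a differentiable stand-in for Huffman/FSE bits per byte.
    centers = torch.arange(num_bins, device=x.device, dtype=x.dtype)
    logits = -(x.unsqueeze(1) - centers).square() / (2.0 * bandwidth ** 2)
    probs = torch.softmax(logits, dim=1).mean(dim=0) + 1e-12
    return -(probs * probs.log2()).sum()
```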

Results (1xH100, 5 runs)

| Run | BPB | Artifact | vs Control |
| --- | --- | --- | --- |
| Control | 1.4374 | 12.32 MB | -- |
| Dict. match only | 1.4463 | 12.15 MB | -173 KB |
| Entropy only | 1.4465 | 11.52 MB | -808 KB |
| Combined | 1.4465 | 11.48 MB | -842 KB |
| Entropy strong | 1.5044 | 9.81 MB | -2.52 MB |

Full writeup, prior art analysis, and debugging notes in the README.

Differentiable proxies for LZ-family compression as a training regularizer. Dictionary matching proxy (multi-lag autocorrelation) + entropy proxy (soft histogram). 6.8% artifact size reduction on 1xH100 at +0.009 BPB cost.
Copilot AI review requested due to automatic review settings April 5, 2026 17:18

Copilot AI left a comment


Pull request overview

Adds a non-record submission demonstrating Compressor-Aware Training (CAT): differentiable proxies for LZ-family compression (dictionary matching + entropy coding) used as a training regularizer to produce more zlib-friendly quantized weights under the 16MB artifact constraint.

Changes:

  • Introduces a new training script (train_golf.py) implementing CAT losses (LZ77-style multi-lag soft autocorrelation + soft-histogram entropy) and logging compressed-size diagnostics (a sketch of such a diagnostic follows this list).
  • Adds a submission manifest (submission.json) capturing metrics and architecture/hyperparameters for the run.
  • Adds a technique writeup (README.md) describing motivation, proxy losses, and ablation results.
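
For reference, a hypothetical sketch of such a compressed-size diagnostic (the function name and placement are illustrative, not the actual code in train_golf.py):

```python
import zlib

import torch

def zlib_bytes_ratio(byte_stream: torch.Tensor) -> float:
    # Ground-truth check on the differentiable proxies: the actual
    # zlib-compressed size of the serialized quantized weight bytes.
    raw = (byte_stream.detach().round().clamp(0, 255)
           .to(torch.uint8).cpu().numpy().tobytes())
    return len(zlib.compress(raw, 9)) / max(len(raw), 1)
```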

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/train_golf.py | Implements CAT regularizers + training/quantization/export pipeline for the non-record submission. |
| records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/submission.json | Records run metadata and CAT hyperparameters for the submission. |
| records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/README.md | Documents CAT approach, debugging notes, and experimental results. |


Comment on lines +1324 to +1342
```python
# CAT regularizer (separate backward pass, outside torch.compile region).
# Uses base_model (not DDP wrapper) — bypasses DDP allreduce hooks.
# This is correct because: (1) DDP keeps weights synchronized across ranks,
# (2) seed=step ensures identical subsampling, so all ranks compute identical
# CAT gradients independently. If either invariant breaks, ranks will diverge.
cat_loss_val = 0.0
if args.cat_lambda_lz > 0 or args.cat_lambda_h > 0:
    if step >= args.cat_start_step:
        cat_device = next(base_model.parameters()).device
        byte_stream = serialize_quantized_weights_torch(base_model)
        cat_loss = torch.tensor(0.0, device=cat_device)
        if args.cat_lambda_lz > 0:
            cat_loss = cat_loss + args.cat_lambda_lz * lz77_proxy_loss_torch(
                byte_stream, args.cat_temperature, args.cat_sample_size, seed=step)
        if args.cat_lambda_h > 0:
            cat_loss = cat_loss + args.cat_lambda_h * entropy_proxy_loss_torch(
                byte_stream, args.cat_bandwidth, args.cat_sample_size, seed=step)
        cat_loss.backward()
        cat_loss_val = cat_loss.item()
```

Copilot AI Apr 5, 2026


In distributed (DDP) runs, the CAT backward pass is executed on base_model outside the DDP wrapper, so its gradients are never all-reduced. The comment assumes all ranks will compute identical CAT gradients, but small nondeterminism (or future refactors) would cause rank divergence. Consider either computing CAT loss through the DDP-wrapped model on a synced backward step, or explicitly all-reducing the affected parameter grads after cat_loss.backward() when distributed is true.
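
A hedged sketch of the explicit all-reduce option, reusing `base_model` and `cat_loss` from the snippet above; gradients that are already identical across ranks are unchanged by averaging, so this only equalizes the CAT contribution:

```python
import torch.distributed as dist

cat_loss.backward()
if dist.is_available() and dist.is_initialized():
    # Average all parameter grads across ranks. Identical (DDP-synced)
    # main-loss grads are a fixed point of averaging, so this only
    # repairs any per-rank divergence in the CAT gradients.
    world_size = dist.get_world_size()
    for param in base_model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)
```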

Comment on lines +653 to +658
```python
for lag in lags:
    if lag >= x.shape[0]:
        break
    diff_sq = (x[lag:] - x[:-lag]).square()
    match_score = match_score + torch.exp(-diff_sq / temperature).mean()
return -match_score / float(len(lags))
```

Copilot AI Apr 5, 2026


lz77_proxy_loss_torch breaks out of the lag loop when lag >= x.shape[0], but still divides by len(lags) (10). For shorter byte_stream inputs this changes the loss scale depending on length. Track the actual number of lags accumulated and divide by that instead (or early-return when x is too short).

Suggested change

```diff
-for lag in lags:
-    if lag >= x.shape[0]:
-        break
-    diff_sq = (x[lag:] - x[:-lag]).square()
-    match_score = match_score + torch.exp(-diff_sq / temperature).mean()
-return -match_score / float(len(lags))
+num_lags = 0
+for lag in lags:
+    if lag >= x.shape[0]:
+        break
+    diff_sq = (x[lag:] - x[:-lag]).square()
+    match_score = match_score + torch.exp(-diff_sq / temperature).mean()
+    num_lags += 1
+if num_lags == 0:
+    return match_score
+return -match_score / float(num_lags)
```

Comment on lines +80 to +88
```python
# Compressor-Aware Training (CAT)
cat_lambda_lz = float(os.environ.get("CAT_LAMBDA_LZ", 0.0))
cat_lambda_h = float(os.environ.get("CAT_LAMBDA_H", 0.0))
cat_temperature = float(os.environ.get("CAT_TEMPERATURE", 50.0))
cat_bandwidth = float(os.environ.get("CAT_BANDWIDTH", 1.0))
cat_sample_size = int(os.environ.get("CAT_SAMPLE_SIZE", 100_000))
cat_start_step = int(os.environ.get("CAT_START_STEP", 0))
cat_log_every = int(os.environ.get("CAT_LOG_EVERY", 50))
```


Copilot AI Apr 5, 2026


CAT hyperparameters can be set via env vars, but there’s no validation that cat_temperature and cat_bandwidth are positive. As written, cat_bandwidth=0 (or negative values) will cause a division-by-zero / invalid logits in entropy_proxy_loss_torch, and non-positive temperatures will invert/degenerate the LZ proxy. Consider validating these in Hyperparameters or right before computing the CAT losses.
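
A minimal guard along these lines (hypothetical placement, right after the env-var block quoted above):

```python
# Fail fast on degenerate CAT hyperparameters instead of producing
# inverted losses or division-by-zero mid-run.
if cat_lambda_lz > 0 and cat_temperature <= 0:
    raise ValueError(f"CAT_TEMPERATURE must be > 0, got {cat_temperature}")
if cat_lambda_h > 0 and cat_bandwidth <= 0:
    raise ValueError(f"CAT_BANDWIDTH must be > 0, got {cat_bandwidth}")
```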

Comment on lines +612 to +633
```python
def serialize_quantized_weights_torch(model: nn.Module) -> torch.Tensor:
    """Serialize fake-quantized weights to flat float tensor of byte values.

    Uses STE (straight-through estimator) so gradients flow through round().
    Without STE, torch.round() has zero gradient and CAT loss cannot update weights.
    Respects QUANT_BITS so the proxy matches the actual quantization grid.
    """
    qmax = (1 << (QUANT_BITS - 1)) - 1  # 31 for int6, 127 for int8
    byte_chunks = []
    for name, param in model.named_parameters():
        if param.ndim < 2 or param.numel() <= 65536:
            continue
        w = param.float()
        row_max = w.abs().amax(dim=-1, keepdim=True)
        scale = torch.clamp(row_max / float(qmax), min=1.0 / float(qmax))
        w_clamped = torch.clamp(w / scale, -qmax, qmax)
        w_rounded = torch.round(w_clamped)
        # STE: forward uses rounded values, backward flows through w_clamped
        w_q = (w_rounded - w_clamped).detach() + w_clamped
        w_bytes = w_q + float(qmax + 1)  # shift to unsigned range
        byte_chunks.append(w_bytes.reshape(-1))
    return torch.cat(byte_chunks)
```

Copilot AI Apr 5, 2026


serialize_quantized_weights_torch claims to emit raw byte values in [0–255], but the current mapping shifts by qmax + 1, yielding [1..255] for int8 (0 never appears) and [1..63] for int6. This mismatch affects the entropy proxy and contradicts the docstring/README description; consider either adjusting the mapping to a true 0-based range or updating the docs to match the actual value range.
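
A sketch of the first option (hypothetical fix): shifting by qmax instead of qmax + 1 maps the symmetric range [-qmax, qmax] onto [0, 2*qmax], i.e. [0, 62] for int6 and [0, 254] for int8:

```python
w_bytes = w_q + float(qmax)  # true 0-based unsigned range [0, 2*qmax]
```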

```python
if args.cat_lambda_lz > 0 or args.cat_lambda_h > 0:
    if step >= args.cat_start_step:
        cat_device = next(base_model.parameters()).device
        byte_stream = serialize_quantized_weights_torch(base_model)
```

Copilot AI Apr 5, 2026


CAT currently materializes a concatenated byte_stream over (most) model parameters every step, then subsamples from it. For larger models this is O(#params) extra work and memory even when cat_sample_size is small, and it is duplicated per-rank under DDP. Consider sampling from parameter tensors directly (without building the full stream), or caching/reusing the serialization for multiple steps when possible, to reduce overhead and OOM risk.

Suggested change

```diff
-byte_stream = serialize_quantized_weights_torch(base_model)
+named_cat_params = [(name, param) for name, param in base_model.named_parameters()]
+total_cat_elems = sum(param.numel() for _, param in named_cat_params)
+cat_sample_budget = int(args.cat_sample_size)
+if 0 < cat_sample_budget < total_cat_elems:
+    rng = random.Random(step)
+    param_indices = list(range(len(named_cat_params)))
+    rng.shuffle(param_indices)
+    sampled_module = nn.Module()
+    sampled_elems = 0
+    sampled_idx = 0
+    while sampled_idx < len(param_indices) and sampled_elems < cat_sample_budget:
+        name, param = named_cat_params[param_indices[sampled_idx]]
+        sampled_module.register_parameter(
+            f"cat_param_{sampled_idx}_{name.replace('.', '_')}",
+            param,
+        )
+        sampled_elems += param.numel()
+        sampled_idx += 1
+    byte_stream = serialize_quantized_weights_torch(sampled_module)
+else:
+    byte_stream = serialize_quantized_weights_torch(base_model)
```
