Non-record: Compressor-Aware Training (CAT), differentiable compression proxies for LZ-family compressors#1385
korentomas wants to merge 2 commits into openai:main from
Conversation
Differentiable proxies for LZ-family compression as a training regularizer. Dictionary matching proxy (multi-lag autocorrelation) + entropy proxy (soft histogram). 6.8% artifact size reduction on 1xH100 at +0.009 BPB cost.
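The entropy proxy named above (a differentiable stand-in for the entropy-coding stage) can be sketched as a soft histogram. This is an illustrative reconstruction, not the PR's actual `entropy_proxy_loss_torch` (the function name, signature, and the Gaussian-softmax binning here are assumptions):

```python
import torch

def entropy_proxy_loss(byte_stream: torch.Tensor, bandwidth: float = 1.0,
                       num_bins: int = 256) -> torch.Tensor:
    """Differentiable entropy estimate via a Gaussian soft histogram.

    byte_stream: float tensor of values in roughly [0, num_bins - 1].
    Each value contributes soft counts to nearby bins, so the histogram
    (and hence the entropy) is differentiable w.r.t. the byte values.
    """
    centers = torch.arange(num_bins, dtype=byte_stream.dtype,
                           device=byte_stream.device)
    # (N, num_bins): soft assignment of each byte value to each bin.
    logits = -((byte_stream.unsqueeze(1) - centers) ** 2) / (2.0 * bandwidth ** 2)
    weights = torch.softmax(logits, dim=1)
    probs = weights.mean(dim=0) + 1e-12  # normalized soft histogram
    return -(probs * probs.log2()).sum()  # entropy in bits; lower = more compressible
```

Minimizing this pushes the quantized byte distribution toward lower entropy, which is what zlib's Huffman stage rewards.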
Pull request overview
Adds a non-record submission demonstrating Compressor-Aware Training (CAT): differentiable proxies for LZ-family compression (dictionary matching + entropy coding) used as a training regularizer to produce more zlib-friendly quantized weights under the 16MB artifact constraint.
Changes:
- Introduces a new training script (`train_golf.py`) implementing CAT losses (LZ77-style multi-lag soft autocorrelation + soft-histogram entropy) and logging compressed-size diagnostics.
- Adds a submission manifest (`submission.json`) capturing metrics and architecture/hyperparameters for the run.
- Adds a technique writeup (`README.md`) describing motivation, proxy losses, and ablation results.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/train_golf.py` | Implements CAT regularizers + training/quantization/export pipeline for the non-record submission. |
| `records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/submission.json` | Records run metadata and CAT hyperparameters for the submission. |
| `records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/README.md` | Documents CAT approach, debugging notes, and experimental results. |
```python
# CAT regularizer (separate backward pass, outside torch.compile region).
# Uses base_model (not DDP wrapper) — bypasses DDP allreduce hooks.
# This is correct because: (1) DDP keeps weights synchronized across ranks,
# (2) seed=step ensures identical subsampling, so all ranks compute identical
# CAT gradients independently. If either invariant breaks, ranks will diverge.
cat_loss_val = 0.0
if args.cat_lambda_lz > 0 or args.cat_lambda_h > 0:
    if step >= args.cat_start_step:
        cat_device = next(base_model.parameters()).device
        byte_stream = serialize_quantized_weights_torch(base_model)
        cat_loss = torch.tensor(0.0, device=cat_device)
        if args.cat_lambda_lz > 0:
            cat_loss = cat_loss + args.cat_lambda_lz * lz77_proxy_loss_torch(
                byte_stream, args.cat_temperature, args.cat_sample_size, seed=step)
        if args.cat_lambda_h > 0:
            cat_loss = cat_loss + args.cat_lambda_h * entropy_proxy_loss_torch(
                byte_stream, args.cat_bandwidth, args.cat_sample_size, seed=step)
        cat_loss.backward()
        cat_loss_val = cat_loss.item()
```
In distributed (DDP) runs, the CAT backward pass is executed on base_model outside the DDP wrapper, so its gradients are never all-reduced. The comment assumes all ranks will compute identical CAT gradients, but small nondeterminism (or future refactors) would cause rank divergence. Consider either computing CAT loss through the DDP-wrapped model on a synced backward step, or explicitly all-reducing the affected parameter grads after cat_loss.backward() when distributed is true.
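The second option the reviewer mentions, explicitly averaging the affected gradients across ranks, could look like the following sketch. The helper name is hypothetical; `base_model` follows the snippet above, and the function is a no-op outside an initialized process group:

```python
import torch
import torch.distributed as dist

def allreduce_cat_grads(base_model: torch.nn.Module) -> None:
    """Average parameter gradients across ranks after cat_loss.backward().

    Guards against rank divergence if the 'identical subsampling' invariant
    is ever broken by nondeterminism or a future refactor.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-process run: nothing to sync
    world_size = dist.get_world_size()
    for param in base_model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)
```

Calling this right after `cat_loss.backward()` makes the CAT gradients match what a synced backward through the DDP wrapper would produce, at the cost of one extra all-reduce over the touched parameters.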
```python
for lag in lags:
    if lag >= x.shape[0]:
        break
    diff_sq = (x[lag:] - x[:-lag]).square()
    match_score = match_score + torch.exp(-diff_sq / temperature).mean()
return -match_score / float(len(lags))
```
lz77_proxy_loss_torch breaks out of the lag loop when lag >= x.shape[0], but still divides by len(lags) (10). For shorter byte_stream inputs this changes the loss scale depending on length. Track the actual number of lags accumulated and divide by that instead (or early-return when x is too short).
```diff
-for lag in lags:
-    if lag >= x.shape[0]:
-        break
-    diff_sq = (x[lag:] - x[:-lag]).square()
-    match_score = match_score + torch.exp(-diff_sq / temperature).mean()
-return -match_score / float(len(lags))
+num_lags = 0
+for lag in lags:
+    if lag >= x.shape[0]:
+        break
+    diff_sq = (x[lag:] - x[:-lag]).square()
+    match_score = match_score + torch.exp(-diff_sq / temperature).mean()
+    num_lags += 1
+if num_lags == 0:
+    return match_score
+return -match_score / float(num_lags)
```
```python
# Compressor-Aware Training (CAT)
cat_lambda_lz = float(os.environ.get("CAT_LAMBDA_LZ", 0.0))
cat_lambda_h = float(os.environ.get("CAT_LAMBDA_H", 0.0))
cat_temperature = float(os.environ.get("CAT_TEMPERATURE", 50.0))
cat_bandwidth = float(os.environ.get("CAT_BANDWIDTH", 1.0))
cat_sample_size = int(os.environ.get("CAT_SAMPLE_SIZE", 100_000))
cat_start_step = int(os.environ.get("CAT_START_STEP", 0))
cat_log_every = int(os.environ.get("CAT_LOG_EVERY", 50))
```
CAT hyperparameters can be set via env vars, but there’s no validation that cat_temperature and cat_bandwidth are positive. As written, cat_bandwidth=0 (or negative values) will cause a division-by-zero / invalid logits in entropy_proxy_loss_torch, and non-positive temperatures will invert/degenerate the LZ proxy. Consider validating these in Hyperparameters or right before computing the CAT losses.
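A minimal fail-fast guard along the lines the reviewer suggests might be the following sketch (the helper name is hypothetical; it wraps the env-var parsing shown above):

```python
import os

def read_positive_float(name: str, default: float) -> float:
    """Parse an env var as a float and reject non-positive values early,
    rather than letting them surface later as division-by-zero or a
    sign-flipped loss inside the CAT proxies."""
    value = float(os.environ.get(name, default))
    if value <= 0:
        raise ValueError(f"{name} must be > 0, got {value}")
    return value

cat_temperature = read_positive_float("CAT_TEMPERATURE", 50.0)
cat_bandwidth = read_positive_float("CAT_BANDWIDTH", 1.0)
```

Failing at startup keeps a typo in a launch script from silently producing NaNs hundreds of steps into a run.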
```python
def serialize_quantized_weights_torch(model: nn.Module) -> torch.Tensor:
    """Serialize fake-quantized weights to flat float tensor of byte values.

    Uses STE (straight-through estimator) so gradients flow through round().
    Without STE, torch.round() has zero gradient and CAT loss cannot update weights.
    Respects QUANT_BITS so the proxy matches the actual quantization grid.
    """
    qmax = (1 << (QUANT_BITS - 1)) - 1  # 31 for int6, 127 for int8
    byte_chunks = []
    for name, param in model.named_parameters():
        if param.ndim < 2 or param.numel() <= 65536:
            continue
        w = param.float()
        row_max = w.abs().amax(dim=-1, keepdim=True)
        scale = torch.clamp(row_max / float(qmax), min=1.0 / float(qmax))
        w_clamped = torch.clamp(w / scale, -qmax, qmax)
        w_rounded = torch.round(w_clamped)
        # STE: forward uses rounded values, backward flows through w_clamped
        w_q = (w_rounded - w_clamped).detach() + w_clamped
        w_bytes = w_q + float(qmax + 1)  # shift to unsigned range
        byte_chunks.append(w_bytes.reshape(-1))
    return torch.cat(byte_chunks)
```
serialize_quantized_weights_torch claims to emit raw byte values in [0–255], but the current mapping shifts by qmax + 1, yielding [1..255] for int8 (0 never appears) and [1..63] for int6. This mismatch affects the entropy proxy and contradicts the docstring/README description; consider either adjusting the mapping to a true 0-based range or updating the docs to match the actual value range.
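If the 0-based option is chosen, shifting by `qmax` instead of `qmax + 1` maps the symmetric range `[-qmax, qmax]` onto `[0, 2*qmax]`. A standalone sketch of that mapping (variable names follow the function above):

```python
import torch

QUANT_BITS = 8
qmax = (1 << (QUANT_BITS - 1)) - 1  # 127 for int8

# fake-quantized values land in [-qmax, qmax] after symmetric quantization
w_q = torch.tensor([-127.0, 0.0, 64.0, 127.0])
w_bytes = w_q + float(qmax)  # 0-based: [0, 2*qmax] = [0, 254] for int8
```

Note that symmetric quantization never emits `-qmax - 1`, so the full `[0, 255]` byte range is unreachable either way; documenting the actual range may be the simpler fix.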
```python
if args.cat_lambda_lz > 0 or args.cat_lambda_h > 0:
    if step >= args.cat_start_step:
        cat_device = next(base_model.parameters()).device
        byte_stream = serialize_quantized_weights_torch(base_model)
```
CAT currently materializes a concatenated byte_stream over (most) model parameters every step, then subsamples from it. For larger models this is O(#params) extra work and memory even when cat_sample_size is small, and it is duplicated per-rank under DDP. Consider sampling from parameter tensors directly (without building the full stream), or caching/reusing the serialization for multiple steps when possible, to reduce overhead and OOM risk.
```diff
-byte_stream = serialize_quantized_weights_torch(base_model)
+named_cat_params = [(name, param) for name, param in base_model.named_parameters()]
+total_cat_elems = sum(param.numel() for _, param in named_cat_params)
+cat_sample_budget = int(args.cat_sample_size)
+if 0 < cat_sample_budget < total_cat_elems:
+    rng = random.Random(step)
+    param_indices = list(range(len(named_cat_params)))
+    rng.shuffle(param_indices)
+    sampled_module = nn.Module()
+    sampled_elems = 0
+    sampled_idx = 0
+    while sampled_idx < len(param_indices) and sampled_elems < cat_sample_budget:
+        name, param = named_cat_params[param_indices[sampled_idx]]
+        sampled_module.register_parameter(
+            f"cat_param_{sampled_idx}_{name.replace('.', '_')}",
+            param,
+        )
+        sampled_elems += param.numel()
+        sampled_idx += 1
+    byte_stream = serialize_quantized_weights_torch(sampled_module)
+else:
+    byte_stream = serialize_quantized_weights_torch(base_model)
```
Summary
Results (1xH100, 5 runs)
Full writeup, prior art analysis, and debugging notes in the README.