This library is intended for the training and analysis of cross-layer sparse coding models, including the Cross-Layer Transcoder as described in "Circuit Tracing: Revealing Computational Graphs in Language Models". Currently, that is the only architecture supported, but support for other related models is planned. The primary (and recommended) training strategy for this library is `BatchTopK`, converting the model to a `JumpReLU` model afterwards. However, I've also included support for `JumpReLU`-native training, as well as `ReLU` and `TopK`, largely for the purpose of experimentation.
Key features:
- Fully tensor-parallel (model parameters are sharded across GPUs along the feature dimension)
- Can train from locally saved activations, from activations stored on the included server and streamed to the training machines, or (soon) from activations generated on the fly
- Key activation functions and variants implemented
A Cross-Layer Transcoder (CLT) is a multi-layer dictionary learning model designed to extract sparse, interpretable features from transformers, using an encoder for each layer and a decoder for each (source layer, destination layer) pair (e.g., 12 encoders and 78 decoders for `gpt2-small`). This implementation focuses on the core functionality needed to train and use CLTs, leveraging `nnsight` for model introspection and `datasets` for data handling.
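To see where the decoder count comes from: with L layers there is one encoder per layer and one decoder for every pair with source layer <= destination layer, i.e. L(L+1)/2 decoders (78 for L=12). A minimal structural sketch, based only on the description above (this is not the library's actual `CrossLayerTranscoder` class):

# Illustrative sketch of the CLT module layout described above (not this library's real implementation).
import torch.nn as nn

num_layers, d_model, num_features = 12, 768, 3072  # gpt2-small-like sizes, chosen for illustration

encoders = nn.ModuleList(nn.Linear(d_model, num_features) for _ in range(num_layers))
# One decoder per (source layer, destination layer) pair with src <= dst.
decoders = nn.ModuleDict({
    f"{src}->{dst}": nn.Linear(num_features, d_model)
    for src in range(num_layers)
    for dst in range(src, num_layers)
})

assert len(encoders) == 12
assert len(decoders) == num_layers * (num_layers + 1) // 2  # = 78 for 12 layers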
Training a CLT involves the following steps:
- Pre-generate activations with `scripts/generate_activations.py` (though an implementation of `StreamingActivationStore` is on the way).
- Train a CLT (start with an expansion factor of at least 32) using this data. Metrics can be logged to WandB. NMSE should get below 0.25, or ideally even below 0.10. As mentioned above, I recommend `BatchTopK` training, and suggest keeping `K` low (200 is a good place to start).
- Convert the model to a `JumpReLU` model using `convert_batchtopk_to_jumprelu.py`. This will estimate a threshold using the formula from the BatchTopK paper. However, this script also implements an additional layerwise calibration step that in practice often improves model performance even beyond what it was at the end of training.
The model will be saved as a `safetensors` object that can be used for other steps, like dashboard generation with SAEDashboard or attribution graph computation with a forthcoming library soon to be linked here.
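If you need the saved weights outside of those tools, they can be read with the `safetensors` library; a small sketch (the path and filename here are assumptions, check your run's output directory for the actual names):

# Sketch: loading saved CLT weights for downstream use (path and filename are illustrative assumptions).
from safetensors.torch import load_file

state_dict = load_file("./clt_output_local/clt_checkpoint_latest.safetensors")
print({k: tuple(v.shape) for k, v in list(state_dict.items())[:5]})  # peek at a few parameter shapes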
# Ensure you have Python 3.8+ and pip installed
git clone https://github.com/curt-tigges/crosslayer-coding.git
cd crosslayer-coding
pip install -e .
# Install optional dependencies if needed (e.g., for HDF5 support, WandB)
# pip install h5py wandb torch --index-url https://download.pytorch.org/whl/cu118 # Example for specific CUDA
For training with locally stored activations (`--activation-source local_manifest`), you first need to generate an activation dataset. This is done using the `scripts/generate_activations.py` script. It extracts MLP input and output activations from a specified Hugging Face model using a given dataset, saving them in HDF5 chunks along with a manifest file (`index.bin`) and metadata (`metadata.json`, `norm_stats.json`).
Key Arguments for `scripts/generate_activations.py`:
- `--model-name`: Hugging Face model name or path (e.g., `gpt2`).
- `--mlp-input-template`, `--mlp-output-template`: NNsight path templates for MLP activations.
- `--dataset-path`: Hugging Face dataset name or path (e.g., `monology/pile-uncopyrighted`).
- `--activation-dir`: Base directory to save the generated activation dataset.
- `--target-total-tokens`: Approximate total number of tokens to generate activations for.
- `--chunk-token-threshold`: Number of tokens to accumulate before saving a chunk.
- `--activation-dtype`: Precision for storing activations (e.g., `bfloat16`, `float16`, `float32`).
- Other arguments control tokenization, batching, storage, and nnsight parameters. Run `python scripts/generate_activations.py --help` for details.
Example Command:
python scripts/generate_activations.py \
--model-name gpt2 \
--dataset-path monology/pile-uncopyrighted \
--mlp-input-template "transformer.h.{}.mlp.c_fc" \
--mlp-output-template "transformer.h.{}.mlp.c_proj" \
--activation-dir ./tutorial_activations \
--target-total-tokens 2000000 \
--chunk-token-threshold 1000000 \
--activation-dtype bfloat16 \
--compute-norm-stats
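To sanity-check the generated dataset, you can inspect the output directory. A small sketch (the directory layout under `./tutorial_activations` and the chunk file naming are assumptions based on the description above):

# Quick sanity check of a generated activation dataset (illustrative sketch; chunk naming is an assumption).
import json
from pathlib import Path
import h5py  # optional dependency mentioned in the install section

dataset_dir = Path("./tutorial_activations/gpt2/pile-uncopyrighted_train")

print(json.loads((dataset_dir / "metadata.json").read_text()))        # generation settings
print((dataset_dir / "norm_stats.json").exists())                     # present if --compute-norm-stats was used
print((dataset_dir / "index.bin").stat().st_size, "bytes of manifest")

for chunk_path in sorted(dataset_dir.glob("*.h5")):
    with h5py.File(chunk_path, "r") as f:
        print(chunk_path.name, list(f.keys()))                        # top-level groups/datasets per chunk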
The easiest way to train a CLT is using the `scripts/train_clt.py` script. This script parses configuration directly from command-line arguments.
Key Arguments:
- `--activation-source`: Must be `local_manifest` or `remote`.
- `--num-features`: Number of CLT features per layer.
- `--model-name`: Base model name (e.g., `gpt2`), used for CLT dimension inference.
- Arguments related to `CLTConfig` and `TrainingConfig` (prefixed appropriately, e.g., `--learning-rate`, `--sparsity-lambda`). See details below.
- `--activation-path`: Required only if `--activation-source=local_manifest`.
- `--server-url`, `--dataset-id`: Required only if `--activation-source=remote`.

Run `python scripts/train_clt.py --help` for a full list of arguments and their defaults.
Configuration via Arguments:
Key configuration parameters are mapped to config classes via script arguments:
- CLTConfig: `--num-features`, `--activation-fn`, `--jumprelu-threshold`, `--clt-dtype`, `--batchtopk-k`, etc. (`num_layers` and `d_model` are derived from `--model-name`). The `--activation-fn` argument selects the feature activation function (see the sketch after this list for the batchtopk/topk distinction):
  - `jumprelu` (default): A learned, per-feature thresholded ReLU.
  - `relu`: Standard ReLU activation.
  - `batchtopk`: Selects the global top K features across all tokens in a batch, based on pre-activation values. The 'k' can be an absolute number or a fraction. This is often used as a training-time differentiable approximation that can later be converted to `jumprelu`.
  - `topk`: Selects the top K features per token (row-wise top-k).
- TrainingConfig: `--learning-rate`, `--training-steps`, `--train-batch-size-tokens`, `--activation-source`, `--activation-path` (for `local_manifest`), remote config fields (for `remote`, e.g. `--server-url`, `--dataset-id`), `--normalization-method`, `--sparsity-lambda`, `--preactivation-coef`, `--optimizer`, `--lr-scheduler`, `--log-interval`, `--eval-interval`, `--checkpoint-interval`, `--dead-feature-window`, and WandB settings (`--enable-wandb`, `--wandb-project`, etc.).
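To make the `batchtopk` vs. `topk` distinction concrete, here is a minimal sketch of the two selection rules on a batch of pre-activations (illustrative only; this is not the library's BatchTopK/TokenTopK implementation, and one common formulation is shown):

# Illustrative sketch of per-token top-k vs. batch-global top-k selection.
import torch

preacts = torch.randn(8, 3072)          # [tokens in batch, features], example sizes
k = 200                                 # per-token feature budget, as suggested above

# topk: keep the top-k features independently for each token (row-wise).
vals, idx = preacts.topk(k, dim=-1)
tokentopk = torch.zeros_like(preacts).scatter_(-1, idx, vals)

# batchtopk: keep the top (k * num_tokens) pre-activations globally across the batch,
# so some tokens may end up with more than k active features and others with fewer.
threshold = preacts.flatten().topk(k * preacts.shape[0]).values.min()
batchtopk = torch.where(preacts >= threshold, preacts, torch.zeros_like(preacts))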
Example: Training from Pre-Generated Local Activations (`local_manifest`)
This mode requires activations generated beforehand (see previous section).
python scripts/train_clt.py \
--activation-source local_manifest \
--activation-path ./tutorial_activations/gpt2/pile-uncopyrighted_train \
--output-dir ./clt_output_local \
--model-name gpt2 \
--num-features 3072 \
--activation-fn jumprelu \
--learning-rate 3e-4 \
--training-steps 50000 \
--train-batch-size-tokens 4096 \
--sparsity-lambda 1e-3 \
--normalization-method auto \
--log-interval 100 \
--eval-interval 1000 \
--checkpoint-interval 1000 \
--enable-wandb --wandb-project clt_training_local
# Add other arguments as needed
Example: Training from a Remote Activation Server (`remote`)
This mode fetches activations from a running `clt_server` instance. See the section on the Remote Activation Server below.
python scripts/train_clt.py \
--activation-source remote \
--server-url http://localhost:8000 \
--dataset-id gpt2/pile-uncopyrighted_train \
--output-dir ./clt_output_remote \
--model-name gpt2 \
--num-features 3072 \
--activation-fn jumprelu \
--learning-rate 3e-4 \
--training-steps 50000 \
--train-batch-size-tokens 4096 \
--sparsity-lambda 1e-3 \
--normalization-method auto \
--log-interval 100 \
--eval-interval 1000 \
--checkpoint-interval 1000 \
--enable-wandb --wandb-project clt_training_remote
# Add other arguments as needed
This library supports feature-wise tensor parallelism built on PyTorch's distributed communication primitives (`torch.distributed`). This shards the model's parameters (encoders, decoders) across multiple GPUs, reducing memory usage per GPU and potentially speeding up computation.
How it Works:
- The `CLTTrainer` automatically detects if it's launched in a distributed environment (via environment variables set by launchers like `torchrun`).
- If distributed, it initializes a process group (`nccl` backend recommended for NVIDIA GPUs).
- The `CrossLayerTranscoder` uses `ColumnParallelLinear` and `RowParallelLinear`, which handle the weight sharding and the necessary communication primitives (`all_gather`, `all_reduce`) automatically.
- The `ActivationStore` (both `Local` and `Remote`) is configured internally to not shard the data (`shard_data=False`) when distributed, ensuring all ranks in the tensor-parallel group receive the same token batch, which is required for feature parallelism.
Launch Script:
Use `torchrun` (or `torch.distributed.launch`) to start the training script on multiple GPUs. `torchrun` handles setting the necessary environment variables (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`, `LOCAL_RANK`).
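For reference, these environment variables are what a trainer typically reads to set up the process group; a minimal sketch of that pattern (illustrative only, the `CLTTrainer` handles this internally):

# Minimal sketch of distributed initialization from torchrun's environment variables (illustrative).
import os
import torch
import torch.distributed as dist

if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    dist.init_process_group(backend="nccl")            # reads MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)                  # one GPU per process
    device = torch.device(f"cuda:{local_rank}")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")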
Example: Training on 4 GPUs (Single Node)
# Example using local manifest data on 4 GPUs
torchrun --nproc-per-node=4 scripts/train_clt.py \
--activation-source local_manifest \
--activation-path ./tutorial_activations/gpt2/pile-uncopyrighted_train \
--output-dir ./clt_output_local_4gpu \
--model-name gpt2 \
--num-features 3072 \
--activation-fn jumprelu \
--learning-rate 3e-4 \
--training-steps 50000 \
--train-batch-size-tokens 4096 \
--sparsity-lambda 1e-3 \
--normalization-method auto \
--log-interval 100 \
--eval-interval 1000 \
--checkpoint-interval 1000 \
--enable-wandb --wandb-project clt_training_local_4gpu
# Add other arguments as needed
Multi-Node Training: Scaling to multiple nodes requires:
- Setting up the `torchrun` command appropriately with `--nnodes`, `--node-rank`, `--rdzv-id`, `--rdzv-backend`, and `--rdzv-endpoint`. See the torchrun documentation.
- Ensuring data accessibility for all nodes:
  - Local Manifest: Requires a shared filesystem mounted at the same path on all nodes.
  - Remote Server: The activation server URL must be reachable from all training nodes. The central server can become a bottleneck; make sure it is fast and has a good network connection.
- Ensuring the `--output-dir` is on a shared filesystem for checkpointing, or implementing custom checkpointing logic to save/load shards from a central location (e.g., cloud storage). The default `CheckpointManager` assumes all ranks can write to the same directory structure.
To resume training from a previously saved checkpoint, use the `--resume-from-checkpoint-dir` argument with the `scripts/train_clt.py` command. The script will attempt to load the latest checkpoint (`clt_checkpoint_latest.safetensors` and `trainer_state_latest.pt` for non-distributed runs, or the `latest/` directory for distributed runs) from the specified directory.
Key aspects of resuming:
- Configuration Loading: When resuming, the script will look for `cli_args.json` in the `--resume-from-checkpoint-dir`. If found, it loads the command-line arguments from the original run. You can override certain parameters by providing them in the current command (e.g., extend `--training-steps`). If `cli_args.json` is not found, a warning is issued and the current command-line arguments are used (you must ensure all necessary configurations are provided).
- Output Directory: The `output_dir` for the resumed run will be the same as the `--resume-from-checkpoint-dir`.
- WandB Resumption: If the original run used Weights & Biases (WandB) and the WandB run ID was saved in the checkpoint (`trainer_state_latest.pt`), the resumed training will attempt to continue logging to the same WandB run. Do not specify `--wandb-run-name` when resuming if you want to continue the original run; the ID from the checkpoint will be used.
- Specific Step: You can resume from a specific checkpoint step (instead of `latest`) by also providing the `--resume-step <step_number>` argument. For non-distributed runs, it will look for `clt_checkpoint_<step_number>.safetensors`. For distributed runs, it will look for the `step_<step_number>/` directory.
- State Restoration: The trainer restores the model weights, optimizer state, scheduler state, gradient scaler state, and the state of the data sampler (including RNG states for PyTorch, NumPy, and Python's `random` module).
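Restoring sampler RNG state uses the standard PyTorch/NumPy/Python APIs; a sketch of what such a save/restore round-trip looks like (illustrative only, not the trainer's exact checkpoint format):

# Sketch of RNG state capture/restore across a checkpoint (illustrative only).
import random
import numpy as np
import torch

def capture_rng_state():
    return {
        "torch": torch.get_rng_state(),
        "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }

def restore_rng_state(state):
    torch.set_rng_state(state["torch"])
    if state["torch_cuda"] is not None:
        torch.cuda.set_rng_state_all(state["torch_cuda"])
    np.random.set_state(state["numpy"])
    random.setstate(state["python"])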
Example: Resuming a Non-Distributed Run
python scripts/train_clt.py \
--resume-from-checkpoint-dir ./clt_output_local \
--enable-wandb
# Optional: --resume-step 10000      # resume from step 10000 specifically
# Optional: --training-steps 60000   # extend training beyond the original steps (total, including past steps)
# Ensure other necessary args are present if cli_args.json is missing or you need to override them.
# For example, if the original run used WandB, include --enable-wandb if it is not in cli_args.json.
Example: Resuming a Distributed Run (e.g., 4 GPUs)
torchrun --nproc-per-node=4 scripts/train_clt.py \
--resume-from-checkpoint-dir ./clt_output_local_4gpu \
--enable-wandb
# Optional: --resume-step 10000
# Optional: --training-steps 60000
Important Notes for Resuming:
- Ensure the activation data specified in the original `cli_args.json` (or provided in the resume command if `cli_args.json` is missing) is still accessible at the same path.
- If you modify critical architectural parameters when resuming (e.g., `--num-features`, or a `--model-name` that leads to a different `d_model` or `num_layers`), it will likely lead to errors when loading the model weights.
Using half-precision (like float16 or bfloat16) can significantly reduce memory footprint and potentially speed up both activation generation and model training, especially on compatible hardware.
1. For Activation Generation:
When generating activation datasets with `scripts/generate_activations.py`:
- Use the `--activation-dtype` argument to specify the precision for storing the activations. Options include `float16`, `bfloat16`, or `float32` (default).

python scripts/generate_activations.py \
# ... other arguments ... \
--activation-dtype float16  # or bfloat16

- Benefit: Storing activations in `float16` or `bfloat16` reduces disk space by roughly half compared to `float32`.
2. For CLT Model Training:
When training a CLT model with `scripts/train_clt.py`:
- Use the `--precision` argument to enable Automatic Mixed Precision (AMP) during training. Options are:
  - `fp16`: Uses float16 for many operations. Requires a GPU with good fp16 support (e.g., NVIDIA Turing architecture or newer).
  - `bf16`: Uses bfloat16. Generally more stable than fp16 for training and often preferred on newer GPUs that support it well (e.g., NVIDIA Ampere A100, Hopper H100).
  - `fp32`: Standard float32 training (default).

# Example for bf16 training
torchrun --nproc-per-node=<num_gpus> scripts/train_clt.py \
# ... other arguments ... \
--precision bf16

# Example for fp16 training
torchrun --nproc-per-node=<num_gpus> scripts/train_clt.py \
# ... other arguments ... \
--precision fp16

- Benefit: Reduces GPU memory usage significantly, allowing for larger models, more features, or bigger batch sizes. Can also lead to faster training on compatible hardware.
- `--fp16-convert-weights`: If you use `--precision fp16`, you can also add the `--fp16-convert-weights` flag. This will convert the model's actual weight parameters to `float16` in addition to using AMP. By default (`--fp16-convert-weights` not set), the master weights are kept in `float32` when using AMP with `fp16`. Converting weights further reduces memory but might slightly impact precision or training stability for some models.
- `--clt-dtype`: You can also specify the data type for the CLT model's parameters directly using `--clt-dtype` (e.g., `float16`, `bfloat16`). If using `--precision fp16` with `--fp16-convert-weights`, the model weights effectively become fp16. If using `--precision bf16`, the model weights also effectively become bf16. Setting `--clt-dtype` explicitly might be useful in specific scenarios or if not using the trainer's precision handling, but typically `--precision` is the primary way to control training precision.
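Under the hood, `--precision fp16`/`bf16` corresponds to the standard PyTorch AMP pattern; a minimal sketch of one such training step (illustrative only, not the trainer's internal loop, and it assumes a CUDA GPU is available):

# Minimal sketch of the AMP pattern that --precision fp16/bf16 enables internally (illustrative).
import torch
import torch.nn as nn

model = nn.Linear(768, 3072).cuda()                           # stand-in for the CLT
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()                          # loss scaling is only needed for fp16
amp_dtype = torch.float16                                     # use torch.bfloat16 for --precision bf16

for _ in range(3):                                            # dummy steps on synthetic data
    x = torch.randn(4096, 768, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=amp_dtype):
        loss = model(x).pow(2).mean()                         # forward (and loss) in reduced precision
    scaler.scale(loss).backward()                             # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)                                    # unscales gradients, then steps
    scaler.update()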
We strongly recommend using `batchtopk` (or `topk`) as the activation function (`--activation-fn batchtopk`). This allows the model to learn sparse features by dynamically selecting the top 'K' features globally (BatchTopK) or per token (TopK).
After training, these implicit thresholds can be converted to explicit, fixed per-feature thresholds for a `jumprelu` activation function. This conversion is performed by the `scripts/convert_batchtopk_to_jumprelu.py` script. The process involves:
- Initial Theta Estimation: The script first estimates initial per-feature thresholds by analyzing the minimum selected pre-activation values of features from the original model over a dataset.
- Layer-wise L0 Calibration (Crucial Step): Since the initial conversion might not perfectly replicate the layer-wise L0 sparsity of the original model, an optional but highly recommended calibration step is performed if `--l0-layerwise-calibrate` is set. This step:
  - Determines the target L0 norm (average number of active features per token) for each layer from the original BatchTopK/TopK model.
  - Adjusts the per-feature JumpReLU thresholds in the converted model, layer by layer, to match these target L0s. This is done by finding a scaling factor for each layer's thresholds via binary search. This calibration helps ensure the converted JumpReLU model closely mimics the sparsity characteristics of the original, better preserving its performance.
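The layer-wise calibration is essentially a scalar binary search per layer; a simplified sketch of that search (illustrative only, not the conversion script's actual code; the search range is an assumption):

# Simplified sketch: calibrate one layer's JumpReLU thresholds to a target L0 via binary search.
import torch

def calibrate_layer(preacts, thetas, target_l0, tol=0.5, max_iters=30):
    """preacts: [tokens, features]; thetas: [features]; target_l0: avg active features per token."""
    lo, hi = 0.1, 10.0                                             # assumed search range for the scale factor
    scale = 1.0
    for _ in range(max_iters):
        scale = (lo + hi) / 2
        l0 = (preacts > thetas * scale).float().sum(dim=-1).mean().item()
        if abs(l0 - target_l0) <= tol:
            break
        if l0 > target_l0:
            lo = scale                                             # too many active features: raise thresholds
        else:
            hi = scale                                             # too few active features: lower thresholds
    return thetas * scale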
Key Arguments for `scripts/convert_batchtopk_to_jumprelu.py`:
- `--batchtopk-checkpoint-path`: Path to the saved BatchTopK model checkpoint directory (e.g., containing `clt.pt` and `cfg.json`).
- `--config-path`: Path to the JSON config file of the BatchTopK model (usually within the checkpoint dir).
- `--activation-data-path`: Path to an activation dataset (manifest directory) for theta estimation and calibration.
- `--output-model-path`: Path to save the converted JumpReLU model's state_dict.
- `--output-config-path`: Path to save the converted JumpReLU model's config.
- `--num-batches-for-theta-estimation`: Number of batches to use for initial theta estimation.
- `--default-theta-value`: Default threshold for features that never activated during estimation.
- `--l0-layerwise-calibrate`: Flag to enable the layer-wise L0 calibration (recommended).
- `--l0-calibration-batches`, `--l0-calibration-batch-size-tokens`: Parameters for the data used during L0 calibration.
- `--l0-target-model-config-path`, `--l0-target-model-checkpoint-path`: Paths to the original model if different from the main input, for deriving L0 targets.
- `--l0-calibration-tolerance`, `--l0-calibration-search-min-scale`, `--l0-calibration-search-max-scale`, `--l0-calibration-max-iters`: Control parameters for the layer-wise calibration search.

Run `python scripts/convert_batchtopk_to_jumprelu.py --help` for details.
Example Command:
python scripts/convert_batchtopk_to_jumprelu.py \
--batchtopk-checkpoint-path /path/to/your/batchtopk_model_checkpoint_dir \
--config-path /path/to/your/batchtopk_model_config.json \
--activation-data-path /path/to/your/activation_dataset_for_estimation_and_calibration \
--output-model-path /path/to/converted_jumprelu_model.pt \
--output-config-path /path/to/converted_jumprelu_config.json \
--num-batches-for-theta-estimation 100 \
--l0-layerwise-calibrate \
--l0-calibration-batches 10 \
--l0-calibration-tolerance 0.5 # Adjust as needed
For large datasets or collaborative environments, activations can be served from a central server using the `clt_server` component (located in the `clt_server/` directory). The `ActivationGenerator` can be configured to upload generated chunks to this server, and the `CLTTrainer` can use `RemoteActivationStore` to fetch batches during training.
Server Functionality:
- Stores activation chunks (HDF5 files), `metadata.json`, and `norm_stats.json` uploaded via HTTP.
- Provides an API for `RemoteActivationStore` to download the manifest (`index.bin`), metadata, and normalization statistics, and to request specific slices of tokens from the stored chunks.
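Conceptually, the client side of this interaction looks like the sketch below. The endpoint paths here are hypothetical placeholders, not `clt_server`'s actual routes; see `clt_server/README.md` for the real API used by `RemoteActivationStore`:

# Conceptual sketch of client/server access (endpoint paths are hypothetical placeholders).
import requests

server_url = "http://localhost:8000"
dataset_id = "gpt2/pile-uncopyrighted_train"

manifest_bytes = requests.get(f"{server_url}/datasets/{dataset_id}/index.bin").content      # hypothetical route
metadata = requests.get(f"{server_url}/datasets/{dataset_id}/metadata.json").json()         # hypothetical route
# During training, the store then requests specific token slices from stored chunks as batches are needed.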
Workflow:
- Generate and Upload: Use `scripts/generate_activations.py` with `--remote_server_url <your_server_address>` and `--storage_type remote`. (Note that `storage_type` is handled via `ActivationGenerator.set_storage_type` and the script no longer passes a `storage_type` argument to `ActivationConfig`; the uploader in `ActivationGenerator` is activated whenever `remote_server_url` is provided in `ActivationConfig`.) This will generate activations and upload them to the specified server.
- Train Remotely: Use `scripts/train_clt.py` with `--activation-source remote`, providing the `--server-url` and the `--dataset-id` (which is typically `<model_name>/<dataset_name>_<split>`). The `RemoteActivationStore` will then fetch data from the server.

For detailed instructions on setting up and running the `clt_server`, please refer to its dedicated README: `clt_server/README.md`.
The script `scripts/scramble_dataset.py` can be used to take an existing locally stored dataset (generated by `generate_activations.py`) and create a new version in which all activation rows are globally shuffled across all chunks. This is useful if you want to train on random samples from the entire dataset without relying on the `random_chunk` sampling strategy during training.
python scripts/scramble_dataset.py \
--input-dir /path/to/original/dataset \
--output-dir /path/to/scrambled/dataset \
--seed 42 # Optional seed for reproducibility
This creates a new directory (`/path/to/scrambled/dataset`) containing the shuffled HDF5 chunks and a correspondingly updated `index.bin` manifest. You can then use this new directory path as the `--activation-path` when training with `--activation-source local_manifest`.
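The core of such a global shuffle is a seeded permutation over all row indices across chunks, which is then used to rewrite rows in the new order. A simplified sketch of the indexing step (illustrative only; the real script also rewrites the HDF5 chunks and the `index.bin` manifest, and the chunk sizes here are made up):

# Simplified sketch of globally shuffling rows across chunks (illustrative only).
import numpy as np

rows_per_chunk = [1_000_000, 1_000_000, 500_000]              # example chunk sizes (assumed)
total_rows = sum(rows_per_chunk)

rng = np.random.default_rng(seed=42)                          # --seed makes the shuffle reproducible
perm = rng.permutation(total_rows)                            # new global order of activation rows

# Map each permuted global index back to its (source chunk, row-within-chunk) location.
chunk_offsets = np.cumsum([0] + rows_per_chunk[:-1])
src_chunk = np.searchsorted(chunk_offsets, perm, side="right") - 1
src_row = perm - chunk_offsets[src_chunk]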
clt/
config/ # Configuration dataclasses (ActivationConfig, CLTConfig, TrainingConfig, InferenceConfig)
models/ # Model implementations (BaseTranscoder, CrossLayerTranscoder, Parallel Layers, Activations)
training/ # Training components (CLTTrainer, LossManager, CLTEvaluator, Checkpointing, Distributed Utils)
data/ # Activation store implementations (Base, Manifest, Local, Remote, Factory, Sampler)
nnsight/ # NNsight integration (ActivationExtractorCLT)
activation_generation/ # Activation pre-generation (ActivationGenerator)
utils/ # Utility functions (minimal)
scripts/ # Example scripts (e.g., train_clt.py, generate_activations.py, scramble_dataset.py, analyze_theta.py, convert_batchtopk_to_jumprelu.py)
clt_server/ # Optional: Remote activation server application
- `ActivationConfig`: Dataclass for activation data source and generation parameters (in `clt/config/data_config.py`). Primarily used by `scripts/generate_activations.py`.
- `CLTConfig`: Dataclass for CLT architecture parameters (in `clt/config/clt_config.py`).
- `TrainingConfig`: Dataclass for the training loop, data source selection (`local_manifest` or `remote`), and hyperparameters (in `clt/config/clt_config.py`).
- `InferenceConfig`: Dataclass for CLT inference/evaluation parameters (in `clt/config/clt_config.py`).
- `CrossLayerTranscoder`: The core CLT model implementation (in `clt/models/clt.py`). Handles different activation functions and tensor parallelism via parallel linear layers.
- `ColumnParallelLinear`, `RowParallelLinear`: Implementations of tensor-parallel linear layers (in `clt/models/parallel.py`).
- Activation functions: `JumpReLU`, `BatchTopK`, `TokenTopK` implementations (in `clt/models/activations.py`).
- `ActivationExtractorCLT`: Extracts MLP activations from a base model using `nnsight` (in `clt/nnsight/extractor.py`). Used by `ActivationGenerator`.
- `ActivationGenerator`: Generates and saves activations based on `ActivationConfig` (in `clt/activation_generation/generator.py`).
- `BaseActivationStore`: Abstract base class for activation stores (in `clt/training/data/base_store.py`).
- `ManifestActivationStore`: Base class for stores using a manifest file (in `clt/training/data/manifest_activation_store.py`). Includes `ChunkRowSampler`.
- `LocalActivationStore`: Manages activation data from local HDF5 files using a manifest (in `clt/training/data/local_activation_store.py`).
- `RemoteActivationStore`: Manages activation data fetched from a remote server using a manifest (in `clt/training/data/remote_activation_store.py`).
- `create_activation_store`: Factory function to instantiate the correct store based on config (in `clt/training/data/activation_store_factory.py`).
- `LossManager`: Calculates reconstruction, sparsity, and pre-activation losses (in `clt/training/losses.py`).
- `CLTEvaluator`: Computes evaluation metrics like L0, feature density, and explained variance (in `clt/training/evaluator.py`).
- `CheckpointManager`: Handles saving and loading model/training state, supporting distributed training (in `clt/training/checkpointing.py`).
- `MetricLogger`: Handles logging metrics to console, file, and WandB (in `clt/training/metric_utils.py`).
- `CLTTrainer`: Orchestrates the training process, integrating all components and handling distributed setup (in `clt/training/trainer.py`). Selects the appropriate activation store.
This example shows how to set up and run the trainer programmatically using pre-generated local activations.
import torch
from pathlib import Path
# ActivationConfig is used by scripts/generate_activations.py, not directly by TrainingConfig here.
from clt.config import CLTConfig, TrainingConfig
from clt.training.trainer import CLTTrainer
# --- Configuration ---
device = "cuda" if torch.cuda.is_available() else "cpu"
output_dir = Path("clt_programmatic_output")
output_dir.mkdir(exist_ok=True, parents=True)
# Determine base model dimensions (e.g., for GPT-2)
# In a script, you might use get_model_dimensions from train_clt.py or infer from config
num_layers = 12
d_model = 768
clt_config = CLTConfig(
num_features=3072, # Example: 4x expansion
num_layers=num_layers,
d_model=d_model,
activation_fn="jumprelu",
jumprelu_threshold=0.03,
model_name="gpt2" # Store base model name for reference
# clt_dtype="bfloat16", # Optional: Specify CLT model dtype
)
# Configure for using local pre-generated data
# Ensure activations are generated first, e.g., via scripts/generate_activations.py
activation_data_dir = "./tutorial_activations/gpt2/pile-uncopyrighted_train" # Assumes this exists
training_config_local = TrainingConfig(
# Core training parameters
learning_rate=3e-4,
training_steps=20000,
train_batch_size_tokens=4096,
# Activation Source: Local Manifest
activation_source="local_manifest",
activation_path=activation_data_dir, # Path to directory with index.bin, metadata.json, etc.
activation_dtype="bfloat16", # Or "float16", "float32" - type for activations from store
# Normalization: "auto" uses norm_stats.json from activation_path if available
# Other options: "none"
normalization_method="auto",
# Loss coefficients
sparsity_lambda=1e-3,
sparsity_c=1.0,
preactivation_coef=3e-6,
# Logging and Evaluation
log_interval=100,
eval_interval=1000,
checkpoint_interval=1000,
# WandB (optional)
enable_wandb=False,
# wandb_project="my-clt-project-local",
)
# --- Trainer Initialization ---
# Note: For multi-GPU, this would be run inside a script launched by torchrun.
# The trainer handles distributed init internally based on env vars.
trainer = CLTTrainer(
clt_config=clt_config,
training_config=training_config_local,
log_dir=str(output_dir),
device=device, # For single GPU; ignored if distributed=True (derived from LOCAL_RANK)
)
# --- Run Training ---
print("Starting training...")
trained_clt_model = trainer.train()
print(f"Training complete! Final model saved in {output_dir}")
# --- Saving/Loading (Trainer handles checkpointing and final model saving) ---
# The trainer saves the model config (cfg.json) and model weights.
# For a model trained with BatchTopK, you might need to convert it to JumpReLU post-training.
# Example:
# from clt.scripts.convert_batchtopk_to_jumprelu import convert_model_to_jumprelu
# clt_checkpoint_path = output_dir / "final"
# config_path = clt_checkpoint_path / "cfg.json"
# converted_model_dir = output_dir / "final_jumprelu"
# converted_model_path = converted_model_dir / "clt_model_jumprelu.pt"
# converted_config_path = converted_model_dir / "cfg_jumprelu.json"
# if config_path.exists():
# # Assuming convert_model_to_jumprelu takes dir path now
# # Needs adjustment based on actual script signature
# # convert_model_to_jumprelu(...)
# print(f"Converted model to JumpReLU and saved to {converted_model_dir}")
# To load a model manually (e.g., after training):
# from clt.models import CrossLayerTranscoder
# loaded_config = CLTConfig.from_json(output_dir / "final" / "cfg.json")
# # For loading a single-GPU or a specific rank's shard:
# loaded_model = CrossLayerTranscoder(loaded_config, process_group=None, device=torch.device(device))
# # Load the state dict (adjust path if loading a specific rank shard)
# state_dict_path = output_dir / "final" / "clt.pt" # Or rank_0_model.pt if saved sharded
# loaded_model.load_state_dict(torch.load(state_dict_path, map_location=device))
# print("Model loaded manually.")
If you use Crosslayer-Coding in your research, please cite:
@software{crosslayer-coding,
author = {Tigges, Curt},
title = {Cross-Layer Coding},
year = {2025},
url = {https://github.com/curt-tigges/crosslayer-coding}
}
MIT