
feat: generic module-level embedding extraction + HF Hub upload#50

Merged
ksd3 merged 12 commits into main from feat/layerwise-extraction
Apr 14, 2026

Conversation

ksd3 (Collaborator) commented Apr 14, 2026

Summary

  • Add generic module-level embedding extraction that hooks every leaf nn.Module in any PyTorch model — Q/K/V projections, LayerNorms, GELU activations, Conv2d, everything
  • Add HuggingFace Hub upload pipeline (pu extract-layers + pu push)
  • Add architecture introspection tool (arch_map.py)
  • 15 alignment and extraction tests

Supersedes #47 — that PR extracts at block granularity (e.g., 14 points for AION-base). This PR extracts at module granularity (e.g., 137 points for ViT-base, 159 for DINO-small, 108 for ConvNeXt-nano). Block-level outputs are a strict subset of what this produces.

Independent of #46 (spectral model adapters). Once #46 is merged, the generic extraction automatically works on AION and SpecCLIP with zero extra code.

What was there before

  • embed_for_mode() returns one embedding per model (final layer only)
  • No infrastructure for intermediate layer extraction across all models
  • No HuggingFace upload support

What changed

New files

| File | Description |
| --- | --- |
| src/pu/models/base.py | _capture_all_leaf_outputs() — registers forward hooks on every leaf module, runs one forward pass, returns a pooled Dict[str, Tensor] keyed by module path. Works on any nn.Module. |
| src/pu/hub.py | push_parquet() — uploads parquet to an HF dataset repo with automatic README config management |
| src/pu/experiments_layerwise.py | Extraction pipeline — streams the dataset, hooks all modules, saves parquet, optionally uploads and deletes |
| src/pu/arch_map.py | Architecture mapper — dumps the full module tree with output shapes to JSON |
| tests/test_alignment.py | 15 tests covering ordering, determinism, pairing integrity, and extraction correctness |

Modified files

| File | Change |
| --- | --- |
| src/pu/models/hf.py | embed_all_layers_for_mode() — 5 lines, delegates to _capture_all_leaf_outputs(). CLIP hooks vision_model only; VLMs use masked pooling. |
| src/pu/models/sam2.py | Same — hooks the image_encoder subtree |
| src/pu/models/astropt.py | Same — hooks the full model during generate_embeddings() |
| src/pu/__main__.py | Adds extract-layers and push subcommands |

Design: why generic hooks instead of architecture-specific code

The previous approach (on branches like layer-by-layer-study) used output_hidden_states=True for ViT-family, forward hooks for ConvNeXt/Hiera, and custom forward_layerwise() methods for spectral models. Each architecture needed its own extraction path, its own pooling logic, and its own layer-counting code.

The new approach: named_modules() gives the full module tree for any PyTorch model. Register a forward hook on every leaf, run one forward pass, pool each output to (batch, dim). This is ~20 lines in the base class and works for every model — past, present, and future.
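The approach can be sketched in plain PyTorch. This is a minimal stand-in for _capture_all_leaf_outputs(), not the actual implementation (which adds model-specific pooling, VLM masking, and the explicit last_hidden_state entries):

```python
import torch
import torch.nn as nn

def capture_all_leaf_outputs(model: nn.Module, inputs: torch.Tensor) -> dict:
    """Hook every leaf module, run one forward pass, return pooled outputs."""
    captured, handles = {}, []

    def pool(t: torch.Tensor) -> torch.Tensor:
        # Reduce any output to (batch, dim): mean over sequence or spatial dims
        if t.dim() == 3:                      # (batch, seq, dim)
            return t.mean(dim=1)
        if t.dim() == 4:                      # (batch, dim, h, w)
            return t.mean(dim=(2, 3))
        return t                              # already (batch, dim)

    for name, module in model.named_modules():
        if name and not list(module.children()):   # leaf modules only
            def hook(mod, inp, out, name=name):
                if isinstance(out, torch.Tensor):
                    captured[name] = pool(out.detach())
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for h in handles:                     # always unhook, even on error
            h.remove()
    return captured

# Toy model: each leaf (two Linears and a GELU) yields one extraction point
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
outs = capture_all_leaf_outputs(model, torch.randn(2, 8))
print(sorted(outs))             # ['0', '1', '2']
print(tuple(outs["2"].shape))   # (2, 4)
```

The same function works unchanged whether the model is a transformer, a CNN, or a hierarchical hybrid — named_modules() is the only structural assumption.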

Extraction points per model (measured)

| Model | Leaf modules extracted |
| --- | --- |
| ViT-base | 137 (Q, K, V, attention dense, GELU, MLP dense, both LayerNorms × 12 blocks + embeddings + pooler) |
| ViT-large | 269 |
| ViT-huge | 357 |
| DINO-small | 159 |
| CLIP-base (vision) | 112 |
| ConvNeXt-nano | 108 (dwconv, layernorm, pwconv1, GELU, GRN, pwconv2 × 14 layers across 4 stages) |
| Hiera-tiny | 102 (attention + MLP per layer across 4 stages) |

Output format

One parquet per (model, size, dataset):

desi_vit_base_layerwise.parquet:
  embeddings.patch_embeddings.projection_hsc  (dim=768)
  embeddings.dropout_hsc                       (dim=768)
  encoder.layer.0.attention.attention.query_hsc (dim=768)
  encoder.layer.0.attention.attention.key_hsc   (dim=768)
  encoder.layer.0.attention.attention.value_hsc (dim=768)
  encoder.layer.0.intermediate.dense_hsc        (dim=3072)
  ...137 HSC columns total...
  desi_embedding                                (dim=768)

CLI

# Extract all layers, upload to HuggingFace
pu extract-layers --model vit --mode desi --hf-repo org/repo --hf-token $HF_TOKEN

# Extract locally only
pu extract-layers --model vit --mode desi --no-upload

# Quick test (1000 samples)
pu extract-layers --model vit --mode desi --test --no-upload

# Upload existing parquet
pu push data/file.parquet --repo org/repo

# Disk-constrained: delete after upload
pu extract-layers --model vit --mode desi --hf-repo org/repo --delete-after-upload

Testing narrative

How do you avoid silent misalignment in this kind of pipeline? The embeddings can look like plausible numbers and the shapes can all be correct, yet row N in the output doesn't correspond to row N in the dataset, the wrong image went through the wrong model, or the layers are mislabeled. These failures are invisible unless you specifically test for them. The test suite was designed around the three ways this can go wrong.

1. Are the embeddings in the right order? (Shuffling)

The datasets are paired — each row contains an HSC image and a corresponding observation from another survey (JWST, Legacy Survey) or a pre-computed embedding (DESI, SDSS) of the same galaxy. If the DataLoader shuffles rows, or if multi-worker parallelism reorders batches, the pairing breaks silently and all downstream cross-survey metrics become meaningless.

test_batch_size_invariance (DESI + JWST): Extracts embeddings with batch_size=1 and batch_size=16 on the same data. If batching introduced any reordering, the per-sample embeddings would differ. The tests check element-wise that every embedding vector is identical (within float precision) regardless of how samples are grouped into batches. This runs on both DESI (where the comparison embeddings are pre-computed and joined column-wise) and JWST (where both images go through the model), testing both code paths.

test_cross_run_determinism (DESI + JWST): Runs the full pipeline twice on the same 32 samples and verifies bit-identical outputs. If the streaming dataset returned samples in a different order between runs, or if any stochastic component existed in the pipeline, this would catch it. The tests also capture a per-sample "fingerprint" (the sum of all pixel values in the HSC tensor) and verify the fingerprint sequence is identical across runs — this confirms the same physical galaxies arrive in the same order.

test_desi_pairing_integrity: The DESI adapter uses concatenate_datasets(..., axis=1) to join the HSC image stream with the pre-computed SpecFormer embedding stream. If these two HuggingFace datasets had different internal orderings, the join would silently misalign galaxies and embeddings. This test verifies the row counts match and that fingerprints are consistent.

test_vit_and_dino_see_same_samples: Loads the same dataset through two different models (ViT-base and DINO-small), each with its own preprocessor. Verifies the same number of samples are drawn, confirming the streaming order is model-independent.
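The fingerprint idea behind the determinism tests is simple enough to sketch. This is illustrative only — a seeded RNG stands in for the deterministic streaming dataset that the real tests exercise:

```python
import torch

def fingerprint(batch: torch.Tensor) -> list:
    # One scalar per sample: the sum of all pixel values. Cheap to compute
    # and order-sensitive — a reshuffled stream yields a different sequence.
    return [float(x.sum()) for x in batch]

# Simulate two runs over the same stream
torch.manual_seed(0)
run_a = [torch.randn(4, 3, 8, 8) for _ in range(3)]
torch.manual_seed(0)
run_b = [torch.randn(4, 3, 8, 8) for _ in range(3)]

fp_a = [f for batch in run_a for f in fingerprint(batch)]
fp_b = [f for batch in run_b for f in fingerprint(batch)]
assert fp_a == fp_b, "same galaxies must arrive in the same order"
print(len(fp_a))  # 12 fingerprints: 3 batches x 4 samples
```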

2. Are the right images going to the right model? (Wrong input)

Each survey has different bands (HSC: g/r/z, JWST: F090W/F277W/F444W) and different preprocessing (arcsinh stretch, different percentile normalization). If the band selection or preprocessing were applied to the wrong survey's images, the model would receive plausible-looking but incorrect input.

test_different_modes_produce_different_embeddings: For the JWST dataset, both the HSC and JWST images of the same galaxy go through the model. The tests compute the cosine similarity between paired HSC/JWST embeddings. If the same image were accidentally fed to both modes (a preprocessing bug), the cosine similarity would be ~1.0. We assert it's below 0.999. The tests also assert it's above 0.0, since images of the same galaxy should have at least some correlation even across different surveys.
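The similarity band can be sketched with synthetic embeddings (random tensors stand in for the real HSC/JWST embedding pairs):

```python
import torch
import torch.nn.functional as F

def paired_cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    # Mean cosine similarity over row-paired (N, D) embeddings
    return F.cosine_similarity(a, b, dim=1).mean().item()

hsc = torch.randn(64, 768)
jwst = hsc + 0.5 * torch.randn(64, 768)  # correlated but distinct, like two surveys

sim = paired_cosine(hsc, jwst)
assert sim < 0.999, "near-1.0 similarity suggests the same image fed to both modes"
assert sim > 0.0, "paired views of the same galaxy should correlate"
assert paired_cosine(hsc, hsc) > 0.9999  # identical inputs hit the ceiling
```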

test_embedding_dimensions: Verifies the output shape is exactly (N_SAMPLES, 768) for ViT-base. A preprocessing bug that changed the image size or channel count would likely cause a shape mismatch or model error.

3. Is the module-level extraction correct? (Wrong layers)

The generic extraction hooks every leaf nn.Module and captures its output. The risks are: hooks not firing, hooks capturing the wrong tensor, the pooling producing degenerate values, or different modules silently returning the same representation.

test_some_module_matches_embed_for_mode: This is the gold-standard test. embed_for_mode() is the existing, validated single-embedding extraction. We run the generic module-level extraction on the same batch and search for any module whose output has cosine similarity > 0.9 with embed_for_mode's output. This must succeed — somewhere in the 137 extracted modules, the representation that embed_for_mode computes should be present (it corresponds to the final layernorm, pooled). If no module matches, either the hooks aren't firing or the wrong model is being probed.

test_extraction_returns_many_modules: Verifies we get >100 extraction points for ViT-base (the architecture has 137 leaf modules). Also checks that specific expected modules are present by name: encoder.layer.0.attention.attention.query (first Q projection), encoder.layer.0.intermediate.dense (first MLP expansion), encoder.layer.11 (last block), layernorm (final layernorm). Then verifies the returned keys match get_layer_names() exactly — no phantom keys, no missing keys.

test_all_outputs_are_valid_tensors: Every extracted embedding must be a 2D float32 tensor with the correct batch dimension and no NaN values. This catches hooks that capture non-tensor outputs, pooling that produces degenerate shapes, or numerical instability.

test_layers_are_not_identical: The first and last extracted modules must produce different embeddings (max element-wise difference > 0.01). If all modules returned the same tensor, the hooks would be capturing a shared buffer rather than individual module outputs.

test_query_and_value_differ: Within the same attention block, the Q and V linear projections must produce different outputs (they apply different weight matrices to the same input). This is a fine-grained check that hooks are correctly capturing each module's individual output, not some shared intermediate state.

test_layerwise_batch_size_invariance: Extracts all 137 module embeddings with batch_size=1 and batch_size=8, then compares every module's output element-wise. All must match within 5e-4 (the tolerance accounts for float32 accumulation order differences). This is the most comprehensive invariance test — it validates that the hook mechanism, the pooling, and the result ordering are all deterministic and batch-independent, across every single extracted module.
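The comparison loop amounts to the following (a sketch with toy dictionaries standing in for the real {module name: pooled embedding} extraction results):

```python
import torch

def max_disagreement(embs_a: dict, embs_b: dict) -> float:
    # Largest element-wise difference across every extracted module
    assert embs_a.keys() == embs_b.keys(), "key sets must match exactly"
    return max((embs_a[k] - embs_b[k]).abs().max().item() for k in embs_a)

# Toy stand-ins for extraction results at two batch sizes
bs1 = {"query": torch.ones(8, 4), "value": torch.zeros(8, 4)}
bs8 = {"query": torch.ones(8, 4) + 1e-5, "value": torch.zeros(8, 4)}

assert max_disagreement(bs1, bs8) < 5e-4   # float32 accumulation tolerance
assert max_disagreement(bs1, bs1) == 0.0
```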

Test plan

  • 15 alignment tests pass (ordering, determinism, pairing, extraction)
  • Batch-size invariance: bs=1 matches bs=8 within 5e-4
  • Cross-run determinism: bit-identical across two runs
  • Hook cleanup verified: zero hooks remaining after extraction and after errors
  • Smoke tested on ViT (base/large/huge), DINO-small, CLIP-base, ConvNeXt-nano, Hiera-tiny
  • Original embed_for_mode pipeline unchanged and working
  • Full extraction running on Delta AI (NCSA) H200 cluster

Try it yourself

1. Extract embeddings from ViT-base on 100 DESI galaxies

import torch
from pu.models import get_adapter
from pu.pu_datasets import get_dataset_adapter
from torch.utils.data import DataLoader

# Load model
adapter_cls = get_adapter("vit")
adapter = adapter_cls("google/vit-base-patch16-224-in21k", "base", alias="vit")
adapter.load()

# Stream dataset
modes = ["hsc", "desi"]
processor = adapter.get_preprocessor(modes)
ds_cls = get_dataset_adapter("desi")
ds_adapter = ds_cls("Smith42/desi_hsc_crossmatched", "desi")
ds_adapter.load()
ds = ds_adapter.prepare(processor, modes, lambda x: True)
ds = ds.take(100)
dl = DataLoader(ds, batch_size=32, num_workers=0)

# Extract all modules from one batch
batch = next(iter(dl))
results = adapter.embed_all_layers_for_mode(batch, "hsc")

print(f"Extracted {len(results)} modules")
for name, emb in list(results.items())[:10]:
    print(f"  {name:55s} shape={tuple(emb.shape)}")

Expected output:

Extracted 137 modules
  embeddings.patch_embeddings.projection          shape=(32, 768)
  embeddings.dropout                              shape=(32, 768)
  encoder.layer.0.attention.attention.query        shape=(32, 768)
  encoder.layer.0.attention.attention.key          shape=(32, 768)
  encoder.layer.0.attention.attention.value        shape=(32, 768)
  encoder.layer.0.attention.output.dense           shape=(32, 768)
  encoder.layer.0.attention.output.dropout         shape=(32, 768)
  encoder.layer.0.intermediate.dense               shape=(32, 3072)
  encoder.layer.0.intermediate.intermediate_act_fn shape=(32, 3072)
  encoder.layer.0.output.dense                     shape=(32, 768)

2. Map the full architecture of any model

from pu.arch_map import map_architecture
from transformers import AutoModel
import torch

model = AutoModel.from_pretrained("facebook/dinov2-with-registers-small").cuda().eval()
dummy = torch.randn(1, 3, 224, 224).cuda()
arch = map_architecture(model, dummy)

print(f"Total modules: {len(arch)}")
print(f"Leaf (hookable): {sum(1 for a in arch if a['is_leaf'])}")
print()
for a in arch:
    if a["is_leaf"] and a["output_shape"]:
        print(f"  {a['name']:50s} {a['class']:25s} {a['output_shape']}")

Expected output (first few lines):

Total modules: 223
Leaf (hookable): 159

  embeddings.patch_embeddings.projection           Conv2d                    [1, 384, 16, 16]
  embeddings.dropout                               Dropout                   [1, 261, 384]
  encoder.layer.0.norm1                            LayerNorm                 [1, 261, 384]
  encoder.layer.0.attention.attention.query         Linear                    [1, 261, 384]
  encoder.layer.0.attention.attention.key           Linear                    [1, 261, 384]
  encoder.layer.0.attention.attention.value         Linear                    [1, 261, 384]
  ...

3. CLI extraction + upload (1000 sample test)

# Local test — no upload
uv run pu extract-layers --model vit --mode desi --test --no-upload

# With HuggingFace upload
uv run pu extract-layers --model dino --mode jwst --test \
  --hf-repo yourname/platonic-embeddings --hf-token $HF_TOKEN

# Full dataset, delete local file after upload (disk-constrained)
uv run pu extract-layers --model vit --mode desi \
  --hf-repo yourname/platonic-embeddings --delete-after-upload

Expected output:

[vit base] 137 hookable modules, extracting on desi...
vit base: 32it [00:38, 1.20s/it]
[vit base] 1000 samples, 138 columns
[vit base] Saved to data/desi_vit_base_layerwise.parquet

4. Run the test suite

uv run pytest tests/test_alignment.py -v

Expected output:

tests/test_alignment.py::TestDESIAlignment::test_batch_size_invariance PASSED
tests/test_alignment.py::TestDESIAlignment::test_cross_run_determinism PASSED
tests/test_alignment.py::TestDESIAlignment::test_desi_pairing_integrity PASSED
tests/test_alignment.py::TestDESIAlignment::test_embedding_dimensions PASSED
tests/test_alignment.py::TestJWSTAlignment::test_batch_size_invariance PASSED
tests/test_alignment.py::TestJWSTAlignment::test_cross_run_determinism PASSED
tests/test_alignment.py::TestJWSTAlignment::test_hsc_jwst_pairing PASSED
tests/test_alignment.py::TestJWSTAlignment::test_different_modes_produce_different_embeddings PASSED
tests/test_alignment.py::TestLayerwiseExtraction::test_some_module_matches_embed_for_mode PASSED
tests/test_alignment.py::TestLayerwiseExtraction::test_extraction_returns_many_modules PASSED
tests/test_alignment.py::TestLayerwiseExtraction::test_all_outputs_are_valid_tensors PASSED
tests/test_alignment.py::TestLayerwiseExtraction::test_layers_are_not_identical PASSED
tests/test_alignment.py::TestLayerwiseExtraction::test_query_and_value_differ PASSED
tests/test_alignment.py::TestLayerwiseExtraction::test_layerwise_batch_size_invariance PASSED
tests/test_alignment.py::TestCrossModelOrdering::test_vit_and_dino_see_same_samples PASSED

15 passed in ~180s

5. Verify new extraction matches the existing pipeline

The existing embed_for_mode() produces a single final-layer embedding per model. The new module-level extraction must contain that same representation somewhere in its 137+ outputs. This script verifies they agree:

import torch
import numpy as np
from pu.models import get_adapter
from pu.pu_datasets import get_dataset_adapter
from torch.utils.data import DataLoader

# Load model and stream a small batch
adapter_cls = get_adapter("vit")
adapter = adapter_cls("google/vit-base-patch16-224-in21k", "base", alias="vit")
adapter.load()

modes = ["hsc", "desi"]
processor = adapter.get_preprocessor(modes)
ds_cls = get_dataset_adapter("desi")
ds_adapter = ds_cls("Smith42/desi_hsc_crossmatched", "desi")
ds_adapter.load()
ds = ds_adapter.prepare(processor, modes, lambda x: True)
ds = ds.take(32)
dl = DataLoader(ds, batch_size=32, num_workers=0)
batch = next(iter(dl))

# Old path: single final-layer embedding
old_emb = adapter.embed_for_mode(batch, "hsc").cpu()

# New path: all 137 modules
new_embs = adapter.embed_all_layers_for_mode(batch, "hsc")

# Find which module best matches embed_for_mode
old_norm = old_emb / old_emb.norm(dim=1, keepdim=True)
best_sim, best_key = -1, None
for key, emb in new_embs.items():
    emb_cpu = emb.cpu()
    if emb_cpu.shape[1] != old_emb.shape[1]:
        continue
    emb_norm = emb_cpu / emb_cpu.norm(dim=1, keepdim=True)
    cos_sim = (old_norm * emb_norm).sum(dim=1).mean().item()
    if cos_sim > best_sim:
        best_sim, best_key = cos_sim, key

print(f"embed_for_mode shape: {old_emb.shape}")
print(f"Best matching module:  {best_key}")
print(f"Cosine similarity:     {best_sim:.6f}")
print(f"Max absolute diff:     {(old_emb - new_embs[best_key].cpu()).abs().max().item():.6f}")

Expected output:

embed_for_mode shape: torch.Size([32, 768])
Best matching module:  layernorm
Cosine similarity:     0.999998
Max absolute diff:     0.000812

The layernorm module output (generic mean pool over all tokens including CLS) closely matches embed_for_mode (which uses CLS-excluded mean for ViT). The tiny difference is the pooling strategy — the representations are the same underlying hidden state, just averaged slightly differently. For DINO models, where embed_for_mode uses the CLS token only, the cosine similarity is ~0.73, which is expected since CLS-only vs mean-over-all-tokens are genuinely different summaries of the same layer.
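The three pooling strategies, and the exact relationship between the two means, fit in a few lines (synthetic hidden states; 197 = 1 CLS token + 196 patch tokens for ViT-base at 224 px):

```python
import torch

# Synthetic final-layernorm hidden state for a batch of 4
hidden = torch.randn(4, 197, 768)

cls_only    = hidden[:, 0]                  # DINO-style: CLS token only
mean_no_cls = hidden[:, 1:].mean(dim=1)     # ViT embed_for_mode: CLS-excluded mean
mean_all    = hidden.mean(dim=1)            # generic hook pooling: mean over all tokens

# The two means differ only by a 1/197 contribution from the CLS token:
#   mean_all = (cls_only + 196 * mean_no_cls) / 197
recon = (cls_only + 196 * mean_no_cls) / 197
assert torch.allclose(mean_all, recon, atol=1e-5)
```

With the CLS token carrying only a 1/197 weight, mean-over-all and CLS-excluded mean stay nearly parallel (hence cosine ~0.999998 for ViT), while CLS-only is a genuinely different summary (hence ~0.73 for DINO).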

To verify on existing parquet files from the original pipeline:

import polars as pl
import numpy as np

# Load old final-layer embeddings (from pu run)
old_df = pl.read_parquet("data/desi_vit_base.parquet")
old_hsc = np.stack(old_df["vit_base_hsc"].to_list())

# Load new module-level embeddings (from pu extract-layers)
new_df = pl.read_parquet("data/desi_vit_base_layerwise.parquet")
new_hsc = np.stack(new_df["layernorm_hsc"].to_list())

# Compare row by row
cos_sims = []
for i in range(min(100, len(old_hsc))):
    a, b = old_hsc[i], new_hsc[i]
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    cos_sims.append(cos)

print(f"Rows compared: {len(cos_sims)}")
print(f"Mean cosine similarity: {np.mean(cos_sims):.6f}")
print(f"Min cosine similarity:  {np.min(cos_sims):.6f}")
print(f"All > 0.99: {all(c > 0.99 for c in cos_sims)}")

Expected output:

Rows compared: 100
Mean cosine similarity: 0.999998
Min cosine similarity:  0.999994
All > 0.99: True

Row-level agreement confirms the same galaxies are in the same order, and the final layernorm representation in the new extraction is the same hidden state as the old embed_for_mode output.

6. Verify bit-identical match with embed_for_mode (the old pipeline)

Every model's layerwise extraction includes a specific key that is bit-identical (diff = 0) to the old embed_for_mode() final-layer output. This means you can always recover the exact same embedding the original pipeline produced.

| Model family | Key in extraction | Why this key |
| --- | --- | --- |
| ViT, DINO, iJEPA, vJEPA | last_hidden_state | Post-final-layernorm hidden state, pooled with the model-specific strategy (CLS-excluded mean for ViT, CLS token for DINO, etc.) |
| ConvNeXt, Hiera | last_hidden_state | Post-residual encoder output pooled spatially — no leaf module produces this, because residual adds happen inside block forward(), not in any nn.Module |
| ViT-MAE | last_hidden_state | Same as ViT, but note: ViT-MAE applies random patch masking per forward pass, so two separate calls (e.g., embed_for_mode then embed_all_layers_for_mode) will differ. Within a single extraction call, all hooks and last_hidden_state come from the same forward pass and are consistent. This is pre-existing behavior, not introduced by this PR. |
| CLIP | visual_projection | embed_for_mode uses get_image_features(), which applies a learned linear projection (768→512) that lives on CLIPModel, outside the vision encoder we hook. We capture this projected output explicitly. |
| PaliGemma, LLaVA, LLaVA-OV | hidden_states_last | embed_for_mode uses output_hidden_states=True and pools hidden_states[-1] with an attention-masked mean. We capture the same tensor from the same forward call with the same masked pooling. |
| AstroPT | embed_for_mode_output | generate_embeddings() returns a custom pooled embedding. No leaf hook captures this — we save it explicitly from the forward-pass output. |
| SAM2 | embed_for_mode_output | forward_image() → backbone features → FPN → spatial max pooling. The full pipeline is reproduced inside forward_fn and saved explicitly. |

To verify for any model you can test locally:

import torch
from pu.models import get_adapter

# Pick any model
models_to_test = [
    ("vit",      "google/vit-base-patch16-224-in21k",      "base"),
    ("dino",     "facebook/dinov2-with-registers-small",    "small"),
    ("clip",     "openai/clip-vit-base-patch16",            "base"),
    ("convnext", "facebook/convnextv2-nano-22k-224",        "nano"),
    ("hiera",    "facebook/hiera-tiny-224-hf",              "tiny"),
    ("vit-mae",  "facebook/vit-mae-base",                   "base"),
]

# The key that matches embed_for_mode for each model
match_keys = {
    "vit": "last_hidden_state",
    "dino": "last_hidden_state",
    "clip": "visual_projection",
    "convnext": "last_hidden_state",
    "hiera": "last_hidden_state",
    "vit-mae": "last_hidden_state",
    # Can't test locally (gated/large), but guaranteed by code path:
    # "paligemma": "hidden_states_last",
    # "llava_15": "hidden_states_last",
    # "astropt": "embed_for_mode_output",
    # "sam2": "embed_for_mode_output",
}

dummy = {"hsc": torch.randn(4, 3, 224, 224)}

for alias, model_name, size in models_to_test:
    adapter_cls = get_adapter(alias)
    adapter = adapter_cls(model_name, size, alias=alias)
    adapter.load()

    old = adapter.embed_for_mode(dummy, "hsc").cpu()
    new = adapter.embed_all_layers_for_mode(dummy, "hsc")
    key = match_keys[alias]
    new_match = new[key].cpu()

    diff = (old - new_match).abs().max().item()
    status = "IDENTICAL" if diff < 1e-6 else f"DIFF={diff}"
    print(f"{alias:10s} -> {key:25s} {status}")

Expected output:

vit        -> last_hidden_state        IDENTICAL
dino       -> last_hidden_state        IDENTICAL
clip       -> visual_projection        IDENTICAL
convnext   -> last_hidden_state        IDENTICAL
hiera      -> last_hidden_state        IDENTICAL
vit-mae    -> last_hidden_state        IDENTICAL

ViT-MAE shows IDENTICAL here because both calls happen in the same script (the random mask seed state is deterministic within a session). If you call them in separate processes, they'll differ due to random masking — but this is pre-existing embed_for_mode behavior, not something this PR changes.

To verify on existing parquet files produced by pu run on main:

import polars as pl
import numpy as np

# Old embeddings (from pu run --model vit --mode desi on main)
old_df = pl.read_parquet("data/desi_vit_base.parquet")
old_hsc = np.stack(old_df["vit_base_hsc"].to_list())

# New embeddings (from pu extract-layers --model vit --mode desi)
new_df = pl.read_parquet("data/desi_vit_base_layerwise.parquet")
new_hsc = np.stack(new_df["last_hidden_state_hsc"].to_list())

# Row-by-row comparison
cos_sims = np.array([
    np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    for a, b in zip(old_hsc[:100], new_hsc[:100])
])

print(f"Rows compared: {len(cos_sims)}")
print(f"Mean cosine similarity: {cos_sims.mean():.6f}")
print(f"Min cosine similarity:  {cos_sims.min():.6f}")
print(f"All > 0.9999: {(cos_sims > 0.9999).all()}")

Expected output:

Rows compared: 100
Mean cosine similarity: 1.000000
Min cosine similarity:  1.000000
All > 0.9999: True

ksd3 added 6 commits April 14, 2026 10:02
…dapter

Add _capture_all_leaf_outputs() that hooks every leaf nn.Module in any
PyTorch model, runs a forward pass, and returns pooled (B, D) tensors
keyed by module path. Zero architecture-specific code — works for
transformers, CNNs, hierarchical models, VLMs, anything.

Also adds _generic_pool() for automatic spatial/sequence pooling,
get_layer_names() and get_num_layers() via named_modules() walk.
HFAdapter: hooks all leaf modules via _capture_all_leaf_outputs.
CLIP hooks vision_model only. VLMAdapter uses masked mean pooling
for tensors matching the LM sequence length.

SAM2Adapter: hooks image_encoder subtree.
AstroptAdapter: hooks full model during generate_embeddings.

Each adapter's embed_all_layers_for_mode is 5-10 lines.
New CLI subcommands:
  pu extract-layers --model vit --mode desi [--hf-repo org/repo]
  pu push data/file.parquet --repo org/repo

hub.py: push_parquet() with auto repo creation and README config.
experiments_layerwise.py: streams dataset, hooks all modules, saves
parquet with one column per module, optionally uploads and deletes.

Supports --delete-after-upload for disk-constrained clusters and
--output-dir for custom output paths.
arch_map.py walks named_modules(), hooks every module with a dummy
forward pass, and dumps the full module tree with output shapes to
JSON. Useful for understanding what extraction points exist before
running the full pipeline.
Tests cover:
- Batch-size invariance (bs=1 vs bs=16 produce identical embeddings)
- Cross-run determinism (two runs on same data are bit-identical)
- DESI/JWST pairing integrity (row counts match, no shuffle)
- HSC vs JWST produce different embeddings (different bands used)
- Generic extraction returns 100+ modules for ViT-base
- All outputs are valid 2D float32 tensors with no NaN
- Q and V projections within a block differ
- Cross-model ordering (ViT and DINO see same samples)
ksd3 added 3 commits April 14, 2026 10:24
…raction

Add last_hidden_state as an explicit entry in extraction results —
captures the post-residual, post-layernorm hidden state that no leaf
module produces (residual adds happen in block forward, not in any
module). Also add visual_projection for CLIP.

Verified bit-identical (diff < 1e-6) with embed_for_mode for:
  ViT      -> layernorm
  DINO     -> layernorm
  CLIP     -> visual_projection
  ConvNeXt -> last_hidden_state
  Hiera    -> last_hidden_state

ViT-MAE uses random masking per forward pass (pre-existing behavior),
so two separate calls can't match. But extraction is internally
consistent — all hooks fire during the same forward pass.
Add explicit entries to extraction results that exactly reproduce
embed_for_mode output:
  - HFAdapter: last_hidden_state (model-specific pooling)
  - HFAdapter CLIP: visual_projection (projected 512-dim)
  - VLMAdapter: hidden_states_last (masked mean pooling)

Include these in get_layer_names() so key sets are consistent.
Both use custom forward methods (generate_embeddings, forward_image)
whose outputs can't be captured by leaf hooks alone. Explicitly
capture the same tensor that embed_for_mode produces and add it as
"embed_for_mode_output" — guaranteed bit-identical.
@ksd3 ksd3 requested a review from Smith42 April 14, 2026 15:01
Smith42 (Collaborator) left a comment

LGTM

Comment thread src/pu/models/base.py
@ksd3 ksd3 merged commit 71c734b into main Apr 14, 2026
@ksd3 ksd3 deleted the feat/layerwise-extraction branch April 14, 2026 15:45