feat: generic module-level embedding extraction + HF Hub upload#50
Merged
…dapter Add _capture_all_leaf_outputs() that hooks every leaf nn.Module in any PyTorch model, runs a forward pass, and returns pooled (B, D) tensors keyed by module path. Zero architecture-specific code — works for transformers, CNNs, hierarchical models, VLMs, anything. Also adds _generic_pool() for automatic spatial/sequence pooling, get_layer_names() and get_num_layers() via named_modules() walk.
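The leaf-hook mechanism described here can be sketched in plain PyTorch. This is a minimal illustration, not the PR's actual code: the function and pooling names are hypothetical stand-ins for `_capture_all_leaf_outputs()` / `_generic_pool()`, and a toy model stands in for the real adapters.

```python
import torch
import torch.nn as nn

def capture_all_leaf_outputs(model, x):
    """Hook every leaf module, run one forward pass, return pooled (B, D) tensors."""
    captured = {}
    hooks = []

    def pool(t):
        # Generic pooling: reduce sequence/spatial dims down to (batch, dim).
        if t.dim() == 2:
            return t                     # already (B, D)
        if t.dim() == 3:
            return t.mean(dim=1)         # (B, seq, D) -> mean over sequence
        if t.dim() == 4:
            return t.mean(dim=(2, 3))    # (B, C, H, W) -> mean over spatial dims
        return t.flatten(1)

    for name, mod in model.named_modules():
        if len(list(mod.children())) == 0:   # leaf module only
            def make_hook(n):
                def hook(_mod, _inp, out):
                    if isinstance(out, torch.Tensor):
                        captured[n] = pool(out.detach())
                return hook
            hooks.append(mod.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in hooks:
            h.remove()
    return captured

model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
out = capture_all_leaf_outputs(model, torch.randn(2, 8))
print(sorted(out))       # ['0', '1', '2']
print(out["2"].shape)    # torch.Size([2, 4])
```

Because the hooks are keyed by `named_modules()` paths, the same ~20 lines apply to any architecture without model-specific branching.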
HFAdapter: hooks all leaf modules via _capture_all_leaf_outputs. CLIP hooks vision_model only. VLMAdapter uses masked mean pooling for tensors matching the LM sequence length. SAM2Adapter: hooks image_encoder subtree. AstroptAdapter: hooks full model during generate_embeddings. Each adapter's embed_all_layers_for_mode is 5-10 lines.
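Masked mean pooling of the kind the VLMAdapter applies can be sketched as follows; the shape conventions (`(B, L, D)` hidden states, `(B, L)` attention mask) are assumed, and the helper name is illustrative.

```python
import torch

def masked_mean_pool(hidden, attention_mask):
    """Mean over sequence positions, counting only non-padding tokens.

    hidden: (B, L, D) hidden states; attention_mask: (B, L), 1 for real tokens.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (B, L, 1)
    summed = (hidden * mask).sum(dim=1)                   # (B, D)
    counts = mask.sum(dim=1).clamp(min=1)                 # (B, 1), avoid /0
    return summed / counts

h = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)
m = torch.tensor([[1, 1, 0]])        # last position is padding
pooled = masked_mean_pool(h, m)
print(pooled)                        # tensor([[2., 3., 4., 5.]])
```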
New CLI subcommands:
- `pu extract-layers --model vit --mode desi [--hf-repo org/repo]`
- `pu push data/file.parquet --repo org/repo`

hub.py: push_parquet() with auto repo creation and README config. experiments_layerwise.py: streams the dataset, hooks all modules, saves parquet with one column per module, and optionally uploads and deletes. Supports --delete-after-upload for disk-constrained clusters and --output-dir for custom output paths.
arch_map.py walks named_modules(), hooks every module with a dummy forward pass, and dumps the full module tree with output shapes to JSON. Useful for understanding what extraction points exist before running the full pipeline.
Tests cover:
- Batch-size invariance (bs=1 vs bs=16 produce identical embeddings)
- Cross-run determinism (two runs on same data are bit-identical)
- DESI/JWST pairing integrity (row counts match, no shuffle)
- HSC vs JWST produce different embeddings (different bands used)
- Generic extraction returns 100+ modules for ViT-base
- All outputs are valid 2D float32 tensors with no NaN
- Q and V projections within a block differ
- Cross-model ordering (ViT and DINO see same samples)
…raction Add last_hidden_state as an explicit entry in extraction results — it captures the post-residual, post-layernorm hidden state that no leaf module produces (residual adds happen in the block forward, not in any module). Also add visual_projection for CLIP. Verified bit-identical (diff < 1e-6) with embed_for_mode for:
- ViT -> layernorm
- DINO -> layernorm
- CLIP -> visual_projection
- ConvNeXt -> last_hidden_state
- Hiera -> last_hidden_state

ViT-MAE uses random masking per forward pass (pre-existing behavior), so two separate calls can't match. But extraction is internally consistent — all hooks fire during the same forward pass.
Add explicit entries to extraction results that exactly reproduce embed_for_mode output:
- HFAdapter: last_hidden_state (model-specific pooling)
- HFAdapter CLIP: visual_projection (projected 512-dim)
- VLMAdapter: hidden_states_last (masked mean pooling)

Include these in get_layer_names() so the key sets are consistent.
Both use custom forward methods (generate_embeddings, forward_image) whose outputs can't be captured by leaf hooks alone. Explicitly capture the same tensor that embed_for_mode produces and add it as "embed_for_mode_output" — guaranteed bit-identical.
Summary
- Hooks every leaf `nn.Module` in any PyTorch model — Q/K/V projections, LayerNorms, GELU activations, Conv2d, everything
- New CLI subcommands (`pu extract-layers` + `pu push`)
- Architecture mapping tool (`arch_map.py`)

Supersedes #47 — that PR extracts at block granularity (e.g., 14 points for AION-base). This PR extracts at module granularity (e.g., 137 points for ViT-base, 159 for DINO-small, 108 for ConvNeXt-nano). Block-level outputs are a strict subset of what this produces.
Independent of #46 (spectral model adapters). Once #46 is merged, the generic extraction automatically works on AION and SpecCLIP with zero extra code.
What was there before
- `embed_for_mode()` returns one embedding per model (final layer only)

What changed
New files
- `src/pu/models/base.py` — `_capture_all_leaf_outputs()` registers forward hooks on every leaf module, runs one forward pass, and returns a pooled `Dict[str, Tensor]` keyed by module path. Works on any `nn.Module`.
- `src/pu/hub.py` — `push_parquet()` uploads parquet to an HF dataset repo with automatic README config management
- `src/pu/experiments_layerwise.py` — streams the dataset, hooks all modules, saves one parquet column per module
- `src/pu/arch_map.py` — dumps the full module tree with output shapes to JSON
- `tests/test_alignment.py`

Modified files
- `src/pu/models/hf.py` — `embed_all_layers_for_mode()` is 5 lines and delegates to `_capture_all_leaf_outputs()`. CLIP hooks `vision_model` only. VLM uses masked pooling.
- `src/pu/models/sam2.py` — hooks the `image_encoder` subtree
- `src/pu/models/astropt.py` — hooks the full model during `generate_embeddings()`
- `src/pu/__main__.py` — `extract-layers` and `push` subcommands

Design: why generic hooks instead of architecture-specific code
The previous approach (on branches like `layer-by-layer-study`) used `output_hidden_states=True` for ViT-family, forward hooks for ConvNeXt/Hiera, and custom `forward_layerwise()` methods for spectral models. Each architecture needed its own extraction path, its own pooling logic, and its own layer-counting code.

The new approach: `named_modules()` gives the full module tree for any PyTorch model. Register a forward hook on every leaf, run one forward pass, and pool each output to `(batch, dim)`. This is ~20 lines in the base class and works for every model — past, present, and future.

Extraction points per model (measured)
Output format
One parquet per (model, size, dataset):
CLI
Testing narrative
How do you avoid silent misalignment in a pipeline like this? The embeddings can look like plausible numbers and the shapes can be correct, yet row N in the output doesn't correspond to row N in the dataset, the wrong image went through the wrong model, or the layers are mislabeled. These failures are invisible unless you specifically test for them. The test suite was designed around the three ways this can go wrong.
1. Are the embeddings in the right order? (Shuffling)
The datasets are paired — each row contains an HSC image and a corresponding observation from another survey (JWST, Legacy Survey) or a pre-computed embedding (DESI, SDSS) of the same galaxy. If the DataLoader shuffles rows, or if multi-worker parallelism reorders batches, the pairing breaks silently and all downstream cross-survey metrics become meaningless.
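The batch-size-invariance idea can be sketched with a toy model; the real tests exercise the full adapters and streaming datasets, so the `extract` helper and model below are purely illustrative.

```python
import torch
import torch.nn as nn

def extract(model, samples, batch_size):
    """Run samples through the model in batches, preserving input order."""
    outs = []
    with torch.no_grad():
        for i in range(0, len(samples), batch_size):
            outs.append(model(samples[i : i + batch_size]))
    return torch.cat(outs, dim=0)

torch.manual_seed(0)
model = nn.Linear(8, 4).eval()
samples = torch.randn(16, 8)

e1 = extract(model, samples, batch_size=1)
e16 = extract(model, samples, batch_size=16)

# Any reordering or batch-dependent behavior would break this check:
assert torch.allclose(e1, e16, atol=1e-6)
print("batch-size invariant")
```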
- `test_batch_size_invariance` (DESI + JWST): Extracts embeddings with `batch_size=1` and `batch_size=16` on the same data. If batching introduced any reordering, the per-sample embeddings would differ. The tests check element-wise that every embedding vector is identical (within float precision) regardless of how samples are grouped into batches. This runs on both DESI (where the comparison embeddings are pre-computed and joined column-wise) and JWST (where both images go through the model), testing both code paths.
- `test_cross_run_determinism` (DESI + JWST): Runs the full pipeline twice on the same 32 samples and verifies bit-identical outputs. If the streaming dataset returned samples in a different order between runs, or if any stochastic component existed in the pipeline, this would catch it. The tests also capture a per-sample "fingerprint" (the sum of all pixel values in the HSC tensor) and verify the fingerprint sequence is identical across runs — this confirms the same physical galaxies arrive in the same order.
- `test_desi_pairing_integrity`: The DESI adapter uses `concatenate_datasets(..., axis=1)` to join the HSC image stream with the pre-computed SpecFormer embedding stream. If these two HuggingFace datasets had different internal orderings, the join would silently misalign galaxies and embeddings. This test verifies the row counts match and that fingerprints are consistent.
- `test_vit_and_dino_see_same_samples`: Loads the same dataset through two different models (ViT-base and DINO-small), each with its own preprocessor. Verifies the same number of samples are drawn, confirming the streaming order is model-independent.

2. Are the right images going to the right model? (Wrong input)
Each survey has different bands (HSC: g/r/z, JWST: F090W/F277W/F444W) and different preprocessing (arcsinh stretch, different percentile normalization). If the band selection or preprocessing were applied to the wrong survey's images, the model would receive plausible-looking but incorrect input.
- `test_different_modes_produce_different_embeddings`: For the JWST dataset, both the HSC and JWST images of the same galaxy go through the model. The tests compute the cosine similarity between paired HSC/JWST embeddings. If the same image were accidentally fed to both modes (a preprocessing bug), the cosine similarity would be ~1.0. We assert it's below 0.999. The tests also assert it's above 0.0, since images of the same galaxy should have at least some correlation even across different surveys.
- `test_embedding_dimensions`: Verifies the output shape is exactly `(N_SAMPLES, 768)` for ViT-base. A preprocessing bug that changed the image size or channel count would likely cause a shape mismatch or model error.

3. Is the module-level extraction correct? (Wrong layers)
The generic extraction hooks every leaf `nn.Module` and captures its output. The risks are: hooks not firing, hooks capturing the wrong tensor, the pooling producing degenerate values, or different modules silently returning the same representation.

- `test_some_module_matches_embed_for_mode`: This is the gold-standard test. `embed_for_mode()` is the existing, validated single-embedding extraction. We run the generic module-level extraction on the same batch and search for any module whose output has cosine similarity > 0.9 with `embed_for_mode`'s output. This must succeed — somewhere in the 137 extracted modules, the representation that `embed_for_mode` computes should be present (it corresponds to the final layernorm, pooled). If no module matches, either the hooks aren't firing or the wrong model is being probed.
- `test_extraction_returns_many_modules`: Verifies we get >100 extraction points for ViT-base (the architecture has 137 leaf modules). Also checks that specific expected modules are present by name: `encoder.layer.0.attention.attention.query` (first Q projection), `encoder.layer.0.intermediate.dense` (first MLP expansion), `encoder.layer.11` (last block), `layernorm` (final layernorm). Then verifies the returned keys match `get_layer_names()` exactly — no phantom keys, no missing keys.
- `test_all_outputs_are_valid_tensors`: Every extracted embedding must be a 2D float32 tensor with the correct batch dimension and no NaN values. This catches hooks that capture non-tensor outputs, pooling that produces degenerate shapes, or numerical instability.
- `test_layers_are_not_identical`: The first and last extracted modules must produce different embeddings (max element-wise difference > 0.01). If all modules returned the same tensor, the hooks would be capturing a shared buffer rather than individual module outputs.
- `test_query_and_value_differ`: Within the same attention block, the Q and V linear projections must produce different outputs (they apply different weight matrices to the same input). This is a fine-grained check that hooks are correctly capturing each module's individual output, not some shared intermediate state.
- `test_layerwise_batch_size_invariance`: Extracts all 137 module embeddings with `batch_size=1` and `batch_size=8`, then compares every module's output element-wise. All must match within 5e-4 (the tolerance accounts for float32 accumulation-order differences). This is the most comprehensive invariance test — it validates that the hook mechanism, the pooling, and the result ordering are all deterministic and batch-independent, across every single extracted module.

Test plan
- `embed_for_mode` pipeline unchanged and working

Try it yourself
1. Extract embeddings from ViT-base on 100 DESI galaxies
Expected output:
2. Map the full architecture of any model
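A sketch of the dummy-forward mapping approach that `arch_map.py` takes, using plain PyTorch; the function name and the toy model are illustrative, not the actual `arch_map.py` code.

```python
import json
import torch
import torch.nn as nn

def map_architecture(model, example_input):
    """Run a dummy forward pass and record each module's output shape."""
    shapes = {}
    hooks = []

    for name, mod in model.named_modules():
        if not name:                 # skip the root module itself
            continue
        def make_hook(n):
            def hook(_mod, _inp, out):
                if isinstance(out, torch.Tensor):
                    shapes[n] = list(out.shape)
            return hook
        hooks.append(mod.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for h in hooks:
            h.remove()
    return shapes

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(8 * 4 * 4, 10),
)
tree = map_architecture(model, torch.randn(1, 3, 4, 4))
print(json.dumps(tree, indent=2))
```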
Expected output (first few lines):
3. CLI extraction + upload (1000 sample test)
Expected output:
4. Run the test suite
Expected output:
5. Verify new extraction matches the existing pipeline
The existing `embed_for_mode()` produces a single final-layer embedding per model. The new module-level extraction must contain that same representation somewhere in its 137+ outputs. This script verifies they agree:

Expected output:
The `layernorm` module output (generic mean pool over all tokens including CLS) closely matches `embed_for_mode` (which uses a CLS-excluded mean for ViT). The tiny difference is the pooling strategy — the representations are the same underlying hidden state, just averaged slightly differently. For DINO models, where `embed_for_mode` uses the CLS token only, the cosine similarity is ~0.73, which is expected since CLS-only vs mean-over-all-tokens are genuinely different summaries of the same layer.

To verify on existing parquet files from the original pipeline:
Expected output:
Row-level agreement confirms the same galaxies are in the same order, and the final layernorm representation in the new extraction is the same hidden state as the old `embed_for_mode` output.

6. Verify bit-identical match with `embed_for_mode` (the old pipeline)

Every model's layerwise extraction includes a specific key that is bit-identical (diff = 0) to the old `embed_for_mode()` final-layer output. This means you can always recover the exact same embedding the original pipeline produced.

- HFAdapter: `last_hidden_state` — the post-residual hidden state is produced in the block `forward()`, not in any `nn.Module`, so it is captured explicitly.
- ViT-MAE: `last_hidden_state` — random masking means two separate forward passes (`embed_for_mode` then `embed_all_layers_for_mode`) will differ. Within a single extraction call, all hooks + `last_hidden_state` are from the same forward pass and are consistent. This is pre-existing behavior, not introduced by this PR.
- CLIP: `visual_projection` — `embed_for_mode` uses `get_image_features()`, which applies a learned linear projection (768→512) that lives on `CLIPModel`, outside the vision encoder we hook. We capture this projected output explicitly.
- VLMAdapter: `hidden_states_last` — `embed_for_mode` uses `output_hidden_states=True` and pools `hidden_states[-1]` with an attention-masked mean. We capture the same tensor from the same forward call with the same masked pooling.
- AstroptAdapter: `embed_for_mode_output` — `generate_embeddings()` returns a custom pooled embedding. No leaf hook captures this — we save it explicitly from the forward pass output.
- SAM2Adapter: `embed_for_mode_output` — `forward_image()` → backbone features → FPN → spatial max pooling. The full pipeline is reproduced inside `forward_fn` and saved explicitly.

To verify for any model you can test locally:
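The bit-identical claim can be illustrated in miniature: capture a toy model's final leaf output via a hook and compare it against a plain forward pass (the stand-in for `embed_for_mode`). The model and names here are illustrative, not the actual `pu` API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4)).eval()
x = torch.randn(2, 8)

# Old-pipeline-style embedding: final-layer output from a plain forward pass.
with torch.no_grad():
    reference = model(x)

# Layerwise extraction: hook every leaf module during one forward pass.
captured = {}
hooks = []
for name, mod in model.named_modules():
    if len(list(mod.children())) == 0:
        def make_hook(n):
            def hook(_m, _i, out):
                captured[n] = out.detach()
            return hook
        hooks.append(mod.register_forward_hook(make_hook(name)))
with torch.no_grad():
    model(x)
for h in hooks:
    h.remove()

# The final module's captured output must be bit-identical to the reference.
diff = (captured["2"] - reference).abs().max().item()
print(f"max |diff| = {diff}")   # 0.0 — same ops on the same input
```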
Expected output:
ViT-MAE shows `IDENTICAL` here because both calls happen in the same script (the random mask seed state is deterministic within a session). If you call them in separate processes, they'll differ due to random masking — but this is pre-existing `embed_for_mode` behavior, not something this PR changes.

To verify on existing parquet files produced by `pu run` on `main`:

Expected output: