
Commit 71c734b

Merge pull request #50 from UniverseTBD/feat/layerwise-extraction
feat: generic module-level embedding extraction + HF Hub upload
2 parents 9e55ce2 + efcbd9b commit 71c734b

10 files changed

Lines changed: 1572 additions & 106 deletions

src/pu/__main__.py

Lines changed: 50 additions & 0 deletions
@@ -47,6 +47,29 @@ def main():
     parser_percentiles.add_argument("--resize-mode", type=str, default="match", choices=["match", "fill"], help="Resize strategy (default: match).")
     parser_percentiles.add_argument("--output", type=str, default="data/percentiles.json", help="Output JSON path (default: data/percentiles.json).")
 
+    # Subparser for layerwise extraction
+    parser_extract = subparsers.add_parser("extract-layers", help="Extract embeddings from all layers of a model.")
+    parser_extract.add_argument("--model", required=True, help="Model to extract (e.g., 'vit', 'dino').")
+    parser_extract.add_argument("--mode", required=True, help="Dataset mode (e.g., 'jwst', 'desi').")
+    parser_extract.add_argument("--batch-size", type=int, default=64, help="Batch size (default: 64, lower than run due to layerwise memory).")
+    parser_extract.add_argument("--num-workers", type=int, default=0, help="Number of data loader workers.")
+    parser_extract.add_argument("--no-resize", dest="resize", action="store_false", help="Disable galaxy resizing.")
+    parser_extract.add_argument("--resize-mode", type=str, default="match", choices=["match", "fill"], help="Resize strategy (default: match).")
+    parser_extract.add_argument("--test", action="store_true", help="Quick test run using only 1000 samples.")
+    parser_extract.add_argument("--test-10k", action="store_true", help="Test run using only 10000 samples.")
+    parser_extract.add_argument("--hf-repo", type=str, default=os.environ.get("PU_HF_REPO"), help="HuggingFace dataset repo ID for upload. Default: $PU_HF_REPO.")
+    parser_extract.add_argument("--hf-token", type=str, default=None, help="HuggingFace token. Default: $HF_TOKEN env var.")
+    parser_extract.add_argument("--no-upload", action="store_true", help="Disable HuggingFace upload (upload is on by default when --hf-repo is set).")
+    parser_extract.add_argument("--delete-after-upload", action="store_true", help="Delete local parquet file after successful upload to HuggingFace. Saves disk space.")
+    parser_extract.add_argument("--output-dir", type=str, default="data", help="Directory to write parquet files (default: data/).")
+
+    # Subparser for pushing parquet files to HuggingFace Hub
+    parser_push = subparsers.add_parser("push", help="Upload parquet files to a HuggingFace dataset repo.")
+    parser_push.add_argument("file", nargs="?", help="Path to a .parquet file to upload.")
+    parser_push.add_argument("--all", action="store_true", help="Upload all .parquet files in data/.")
+    parser_push.add_argument("--repo", required=True, help="HuggingFace dataset repo ID (e.g., 'org/dataset-name').")
+    parser_push.add_argument("--token", type=str, default=None, help="HuggingFace token. Default: $HF_TOKEN env var.")
+
     # Subparser for benchmarking performance optimizations
     parser_benchmark = subparsers.add_parser("benchmark", help="Run performance benchmarks with optimization flags.")
     parser_benchmark.add_argument("--model", required=True, help="Model to benchmark (e.g., 'vit', 'dino').")
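
For orientation, a hedged sketch of exercising the two new subcommands through this entry point rather than from a shell. It assumes the package is importable as pu and that main() reads sys.argv; the model, mode, and repo values are illustrative, and running it would start the actual extraction/upload.

import sys
from pu.__main__ import main

# Roughly: python -m pu extract-layers --model vit --mode jwst --test --no-upload
sys.argv = ["pu", "extract-layers", "--model", "vit", "--mode", "jwst", "--test", "--no-upload"]
main()

# Roughly: python -m pu push --all --repo org/dataset-name
sys.argv = ["pu", "push", "--all", "--repo", "org/dataset-name"]
main()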
@@ -162,6 +185,33 @@ def main():
             resize_mode=args.resize_mode,
             output_path=args.output,
         )
+    elif args.command == "extract-layers":
+        from pu.experiments_layerwise import extract_all_layers
+        if args.mode in PAIRED_MODES and args.num_workers > 0:
+            print(f"Warning: Setting num_workers=0 for paired mode '{args.mode}' because multiple workers can change draw order and break pairing.")
+            args.num_workers = 0
+        extract_all_layers(
+            args.model,
+            args.mode,
+            batch_size=args.batch_size,
+            num_workers=args.num_workers,
+            max_samples=1000 if args.test else 10000 if args.test_10k else None,
+            resize=args.resize,
+            resize_mode=args.resize_mode,
+            output_dir=args.output_dir,
+            hf_repo=args.hf_repo,
+            hf_token=args.hf_token,
+            upload=not args.no_upload,
+            delete_after_upload=args.delete_after_upload,
+        )
+    elif args.command == "push":
+        from pu.hub import push_parquet, push_all
+        if args.all:
+            push_all("data", args.repo, token=args.token)
+        elif args.file:
+            push_parquet(args.file, args.repo, token=args.token)
+        else:
+            parser.error("Specify a file or --all")
     elif args.command == "benchmark":
         from pu.benchmark import run_benchmark, BenchmarkConfig
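
The same extraction can also be driven without the CLI. A minimal sketch that mirrors the keyword arguments this branch forwards to extract_all_layers; the model/mode strings, paths, and repo handling are illustrative values, not project defaults.

from pu.experiments_layerwise import extract_all_layers

extract_all_layers(
    "vit",                      # model alias, as for --model
    "jwst",                     # dataset mode, as for --mode
    batch_size=64,
    num_workers=0,              # keep 0 for paired modes, as the CLI enforces
    max_samples=1000,           # --test; 10000 for --test-10k, None for a full run
    resize=True,
    resize_mode="match",
    output_dir="data",
    hf_repo=None,               # an "org/dataset-name" repo ID enables upload
    hf_token=None,
    upload=False,               # equivalent to --no-upload
    delete_after_upload=False,
)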

src/pu/arch_map.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+"""Map the full module tree of any PyTorch model to a machine-readable JSON.
+
+Every nn.Module in the tree is a valid hook point. This module:
+1. Walks the full named_modules() graph
+2. Probes each module with a dummy forward to get output shapes
+3. Dumps a JSON file describing every extractable point
+
+Usage:
+    from pu.arch_map import map_architecture
+    arch = map_architecture(model, dummy_input)
+    # arch is a list of dicts, each with:
+    # name, class, output_shape, num_params, depth, is_leaf
+"""
+
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+
+def map_architecture(model, dummy_input, device="cuda"):
+    """Walk the full module tree and probe output shapes.
+
+    Args:
+        model: Any nn.Module (already on device, in eval mode).
+        dummy_input: A tensor that can be passed to model(dummy_input).
+            For CLIP, pass pixel_values; for VLMs, pass a dict.
+        device: Device string.
+
+    Returns:
+        List of dicts, one per named module (excluding root).
+    """
+    shapes = {}
+    hooks = []
+
+    def _make_hook(name):
+        def hook(module, input, output):
+            if isinstance(output, tuple):
+                t = output[0]
+            elif isinstance(output, dict):
+                # Some modules return dicts (e.g., BaseModelOutput)
+                t = next(iter(output.values())) if output else None
+            else:
+                t = output
+            if isinstance(t, torch.Tensor):
+                shapes[name] = list(t.shape)
+            else:
+                shapes[name] = None
+        return hook
+
+    # Register hooks on every module
+    for name, mod in model.named_modules():
+        if name:
+            h = mod.register_forward_hook(_make_hook(name))
+            hooks.append(h)
+
+    # Forward pass to capture all shapes
+    with torch.no_grad():
+        try:
+            if isinstance(dummy_input, dict):
+                model(**dummy_input)
+            else:
+                model(dummy_input)
+        except Exception as e:
+            print(f"Warning: forward pass raised {e.__class__.__name__}: {e}")
+
+    # Remove all hooks
+    for h in hooks:
+        h.remove()
+
+    # Build architecture map
+    arch = []
+    for name, mod in model.named_modules():
+        if not name:
+            continue
+        # Count depth by dots
+        depth = name.count(".") + 1
+        # Is leaf = has no children
+        is_leaf = len(list(mod.children())) == 0
+        # Parameter count (non-recursive to avoid double counting)
+        num_params = sum(p.numel() for p in mod.parameters(recurse=False))
+
+        entry = {
+            "name": name,
+            "class": mod.__class__.__name__,
+            "output_shape": shapes.get(name),
+            "num_params": num_params,
+            "depth": depth,
+            "is_leaf": is_leaf,
+        }
+        arch.append(entry)
+
+    return arch
+
+
+def map_all_models(output_dir="data/architectures", batch_size=2, image_size=224):
+    """Map architectures for all registered models and save as JSON files.
+
+    Loads each model, runs a dummy forward, and saves the full module tree.
+    """
+    from pu.models import get_adapter
+    from pu.experiments_layerwise import MODEL_MAP
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    dummy_img = torch.randn(batch_size, 3, image_size, image_size)
+
+    for alias, (sizes, model_names) in MODEL_MAP.items():
+        # Just map the first (smallest) size
+        size, model_name = sizes[0], model_names[0]
+        out_path = output_dir / f"{alias}_{size}.json"
+
+        if out_path.exists():
+            print(f"[skip] {out_path} already exists")
+            continue
+
+        print(f"\n[{alias} {size}] Loading {model_name}...")
+        try:
+            adapter_cls = get_adapter(alias)
+            adapter = adapter_cls(model_name, size, alias=alias)
+            adapter.load()
+        except Exception as e:
+            print(f" [error] Could not load: {e}")
+            continue
+
+        model = adapter.model
+        device = next(model.parameters()).device
+
+        # Determine the right input for this model type
+        if alias in ("clip",):
+            dummy = dummy_img.to(device)
+            # CLIP needs pixel_values kwarg for full model, but we map vision_model
+            model_to_map = model.vision_model
+            dummy_for_map = dummy
+        else:
+            model_to_map = model
+            dummy_for_map = dummy_img.to(device)
+
+        print(f" Mapping {sum(1 for _ in model_to_map.named_modules()) - 1} modules...")
+        arch = map_architecture(model_to_map, dummy_for_map, device=str(device))
+
+        with open(out_path, "w") as f:
+            json.dump({
+                "model_alias": alias,
+                "model_size": size,
+                "model_name": model_name,
+                "num_modules": len(arch),
+                "num_leaf_modules": sum(1 for a in arch if a["is_leaf"]),
+                "total_params": sum(a["num_params"] for a in arch),
+                "modules": arch,
+            }, f, indent=2)
+
+        print(f" Saved to {out_path} ({len(arch)} modules)")
+
+        # Cleanup
+        del adapter, model, model_to_map
+        import gc
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    map_all_models()
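
As a quick sanity check of the hook-based probing in map_architecture, a minimal sketch (only torch is needed besides the package itself; the toy Sequential is illustrative) of mapping a small model and inspecting the entries it returns:

import torch
import torch.nn as nn
from pu.arch_map import map_architecture

toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
).eval()

arch = map_architecture(toy, torch.randn(1, 3, 32, 32), device="cpu")
for entry in arch:
    print(entry["name"], entry["class"], entry["output_shape"], entry["num_params"])
# e.g. 0 Conv2d [1, 8, 32, 32] 224 / 1 ReLU [1, 8, 32, 32] 0 / ...

These per-module dicts are the same entries that map_all_models wraps, together with model metadata and aggregate counts (num_modules, num_leaf_modules, total_params), into one JSON file per model alias under data/architectures/.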
