recon_inference.py
import os
import sys
import argparse
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from source.dataset import EEGDataset
from source.models import ENIGMA
from source.utils import get_eegfeatures, set_seed, compute_retrieval_metrics
import pandas as pd
# TF32 matmuls are faster than standard float32 on Ampere and newer GPUs
torch.backends.cuda.matmul.allow_tf32 = True
def recon_inference(
model_path,
config_name,
cache_path,
output_path,
model_name,
subj_ids,
retrieval_only,
repetitions,
seed,
):
# Set random seeds for reproducibility
if seed is not None:
set_seed(seed)
# Set device
device = torch.device("cuda")
subjects = [f"sub-{subj:02d}" for subj in subj_ids]
torch.hub.set_dir(cache_path)
output_path = os.path.join(output_path, model_name)
model_path = os.path.join(model_path, model_name)
os.makedirs(output_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)
test_dataset = EEGDataset(
config_name,
subjects=[subjects[0]],
split="test",
)
    # Infer EEG dimensions (channels x timepoints) for the model input
num_channels, num_timepoints = (
test_dataset.eeg_data.shape[-2],
test_dataset.eeg_data.shape[-1],
)
    embed_dim = 1024
    if not retrieval_only:
        from source.models import SDXL_Reconstructor
        generator = SDXL_Reconstructor(device=device, cache_dir=cache_path)
    model = ENIGMA(
        num_channels,
        num_timepoints,
        subjects=subjects,
        embed_dim=embed_dim,
        retrieval_only=retrieval_only,
    )
# Assert that the checkpoint exists
    assert os.path.exists(f"{model_path}/last.pth"), (
        f"Checkpoint not found at {model_path}/last.pth."
        " If you intended to load best.pth instead, did you train with a"
        " validation set?"
    )
model_weights = torch.load(f"{model_path}/last.pth", map_location=device)
    # Load the checkpoint; strict=False tolerates missing or unexpected keys
model.load_state_dict(model_weights, strict=False)
model = model.to(device)
model.eval()
for subject in subjects:
# Load the data
test_dataset = EEGDataset(
config_name,
subjects=[subject],
split="test",
)
subject_dataloader = DataLoader(
test_dataset,
batch_size=200,
shuffle=False,
num_workers=0, # Adjust based on CPU cores
pin_memory=False,
            drop_last=False,  # Keep the final batch even if it is smaller
)
# Extract EEG features
eeg_test_features = get_eegfeatures(
model, subject_dataloader, device, "test", output_path
)
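        # NOTE: eeg_test_features is assumed to be an [n_test, embed_dim]
        # tensor of predicted image-space embeddings, one row per stimulus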
        # Retrieval benchmarking: compare predicted embeddings to ground truth
ground_truth_embeds = test_dataset.get_image_features()
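        # Normalize so dot products with these embeddings equal cosine similarity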
ground_truth_embeds /= ground_truth_embeds.norm(dim=-1, keepdim=True)
if retrieval_only:
# Compute retrieval metrics
topk_accuracy = compute_retrieval_metrics(
eeg_test_features, ground_truth_embeds
)
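            # topk_accuracy is assumed to map each k to the fraction of test
            # stimuli whose correct image is among the top-k retrievals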
df = pd.DataFrame(
[{f"Top {k} Accuracy": v for k, v in topk_accuracy.items()}]
)
df.to_csv(
os.path.join(
model_path, f"retrieval_statistics_{subject}.csv"
),
index=False,
)
else:
# Configure DataLoader
test_dataloader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=1, # Adjust based on CPU cores
pin_memory=False,
                drop_last=False,  # Keep the final batch even if it is smaller
persistent_workers=True,
)
# Inference Loop
with torch.no_grad():
for batch in tqdm(
test_dataloader,
desc=f"Subject {subject} Reconstruction loop",
file=sys.stdout,
):
stimuli = int(batch.class_id[0])
sample_path = os.path.join(
output_path, f"reconstructions_{subject}", str(stimuli)
)
os.makedirs(sample_path, exist_ok=True)
# Save ground truth image
gt_image_path = batch.img_path[0]
Image.open(gt_image_path).save(
os.path.join(sample_path, "gt_image.jpg"),
format="JPEG", # Explicitly set format
quality=75, # Quality from 1 (worst) to 95 (best); lower means more compression
optimize=True, # Optimize the Huffman tables
progressive=True,
)
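                    # Persist the predicted embedding; it is re-read below when
                    # aggregating final_embeds for this subject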
torch.save(
eeg_test_features[stimuli].cpu(),
os.path.join(sample_path, "predicted_embeds.pt"),
)
# Generate the images
for rep in range(repetitions):
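                        # generator.reconstruct is assumed to return one PIL
                        # image conditioned on the predicted embedding c_i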
image = generator.reconstruct(
c_i=eeg_test_features[stimuli].unsqueeze(0),
n_samples=1,
)
# Save the PIL Image
image.save(
os.path.join(sample_path, f"{rep}.jpg"),
format="JPEG", # Explicitly set format
quality=75, # Quality from 1 (worst) to 95 (best); lower means more compression
optimize=True, # Optimize the Huffman tables
progressive=True,
                        )
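            # Aggregate the per-stimulus JPEGs and embeddings into single
            # tensors, presumably for downstream evaluation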
final_recons = torch.zeros(
(len(test_dataset), repetitions, 3, 224, 224)
)
final_embeds = torch.zeros((len(test_dataset), embed_dim))
for stimulus in range(len(test_dataset)):
embeds_path = os.path.join(
output_path,
f"reconstructions_{subject}",
f"{stimulus}",
"predicted_embeds.pt",
                )
                if os.path.exists(embeds_path):
                    final_embeds[stimulus] = torch.load(embeds_path)
for rep in range(repetitions):
recon_path = os.path.join(
output_path,
f"reconstructions_{subject}",
f"{stimulus}",
f"{rep}.jpg",
)
if os.path.exists(recon_path):
recon_image = Image.open(recon_path)
final_recons[stimulus, rep] = transforms.ToTensor()(
recon_image.resize((224, 224))
)
else:
print(
"Reconstruction not found for"
f" stimulus {stimulus}, repetition {rep}"
)
            # Save the aggregated reconstructions and embeddings
            torch.save(
                final_recons,
                os.path.join(output_path, f"final_recons_{subject}.pt"),
            )
            # File name mirrors final_recons_*; chosen here because the
            # original allocated final_embeds without ever saving it
            torch.save(
                final_embeds,
                os.path.join(output_path, f"final_embeds_{subject}.pt"),
            )
def main():
"""
    Parse command-line arguments and run reconstruction inference.
"""
parser = argparse.ArgumentParser(
description=(
"Run Reconstruction Inference with Distributed Data Parallel"
" (DDP)."
)
)
parser.add_argument(
"--model_path",
type=str,
default="train_logs",
help="Path to where model weights and training metadata are stored.",
)
parser.add_argument(
"--cache_path",
type=str,
default="cache",
help=(
"Path to where misc. files downloaded from HuggingFace or Torch"
" Hub are stored. Defaults to shared directory."
),
)
parser.add_argument(
"--output_path",
type=str,
default="output",
help="Path to where the features and reconstructions are stored.",
)
parser.add_argument(
"--config_name",
type=str,
default="things_eeg2",
help=(
"Name of the config to load for the dataset (looks in configs"
" directory)."
),
)
parser.add_argument(
"--model_name",
type=str,
default="ENIGMA",
help="Name of model, used for checkpoint saving",
)
parser.add_argument(
"--subj_ids",
nargs="+",
type=int,
default=[1],
help="List of subject IDs to train on.",
)
parser.add_argument(
"--retrieval_only",
action=argparse.BooleanOptionalAction,
default=False,
help="Only perform retrieval grid creation, no reconstructions.",
)
parser.add_argument(
"--repetitions",
type=int,
default=10,
help="Number of repetitions to sample for each sample.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Seed for random number generators.",
)
args = parser.parse_args()
# Print arguments to the slurm log for reference
print("recon_inference.py ARGUMENTS:\n-----------------------")
for arg, value in vars(args).items():
print(f"{arg}: {value}")
print("-----------------------")
    # The script itself runs on a single GPU; control device visibility via
    # the devices you allocate with SLURM (e.g., CUDA_VISIBLE_DEVICES)
    num_gpus = torch.cuda.device_count()
    print(f"Available GPUs: {num_gpus}")
    if num_gpus < 1:
        raise ValueError("No GPUs available for reconstruction.")
recon_inference(
model_path=args.model_path,
config_name=args.config_name,
cache_path=args.cache_path,
output_path=args.output_path,
model_name=args.model_name,
subj_ids=args.subj_ids,
retrieval_only=args.retrieval_only,
repetitions=args.repetitions,
seed=args.seed,
)
if __name__ == "__main__":
main()
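
# Example invocation (illustrative; all flags are defined in the parser above):
#   python recon_inference.py --model_path train_logs --model_name ENIGMA \
#       --config_name things_eeg2 --subj_ids 1 2 --repetitions 10 --seed 42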