-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
78 lines (47 loc) · 1.95 KB
/
inference.py
File metadata and controls
78 lines (47 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from diffusion import create_cifar10_ddpm
import os
def load_model_from_checkpoint(checkpoint_path, device='cuda'):
    """Build a fresh CIFAR-10 DDPM, restore its weights from disk and ready it for inference.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` whose
            top-level object is indexable by ``'model_state_dict'``.
        device: Target device (string or ``torch.device``). It is also used as
            ``map_location`` so the stored tensors are loaded onto it directly.

    Returns:
        The restored model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.

    Note:
        ``torch.load`` unpickles the file (unless weights-only loading is in
        effect, which depends on the installed torch version), so only load
        checkpoints from a trusted source.
    """
    net = create_cifar10_ddpm()
    ckpt = torch.load(checkpoint_path, map_location=device)
    weights = ckpt['model_state_dict']
    net.load_state_dict(weights)
    net = net.to(device)
    # Eval mode changes the behaviour of layers such as dropout / batch-norm,
    # if the architecture contains any.
    net.eval()
    return net
def generate_samples(model, num_samples=16, device='cuda', save_path='generated_samples.png', nrow=4):
    """Draw samples from ``model``, rescale them to [0, 1] and save them as an image grid.

    Args:
        model: Object exposing ``sample(num_samples, device)`` that returns a
            batch of images as a tensor. The rescaling below assumes the values
            lie in [-1, 1] (TODO confirm against ``model.sample``).
        num_samples: Number of images to draw.
        device: Device forwarded to ``model.sample``.
        save_path: Destination file for the grid image. Missing parent
            directories are created.
        nrow: Number of images per row in the saved grid. The default of 4
            gives a 4x4 grid for the default 16 samples.

    Returns:
        The rescaled samples tensor, clamped to [0, 1].
    """
    with torch.no_grad():
        samples = model.sample(num_samples, device)
        # Map the presumed [-1, 1] model output to [0, 1] and clip any overshoot.
        samples = torch.clamp((samples + 1) / 2, 0, 1)

    # save_image fails if the target directory does not exist yet.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    vutils.save_image(samples, save_path, nrow=nrow, padding=2)
    print(f"💾 Samples saved to {save_path}")
    return samples
def display_samples(samples, title="Generated CIFAR-10 Images", nrow=4):
    """Show a batch of images as a single grid in a matplotlib figure.

    Args:
        samples: Image batch (or list of image tensors) accepted by
            ``torchvision.utils.make_grid``. Values are expected in [0, 1], as
            produced by ``generate_samples``.
        title: Figure title.
        nrow: Number of images per row in the displayed grid.
    """
    grid = vutils.make_grid(samples, nrow=nrow, padding=2)
    # detach() so tensors that still require grad can be converted to numpy;
    # permute moves the channel axis last, as imshow expects (H, W, C).
    image = grid.detach().permute(1, 2, 0).cpu().numpy()

    fig = plt.figure(figsize=(12, 12))
    plt.imshow(image)
    plt.title(title, fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    # Release the figure; with non-blocking backends it would otherwise stay
    # registered in pyplot and accumulate across calls.
    plt.close(fig)
def main(checkpoint_path='../Downloads/ddpm_epoch_70.pth', checkpoint_dir='checkpoints/', num_samples=16):
    """Load a DDPM checkpoint, generate a batch of samples and display them.

    Args:
        checkpoint_path: Checkpoint file to load. A relative path is resolved
            against the current working directory, not this script's location.
        checkpoint_dir: Directory scanned for ``*.pth`` files to suggest when
            ``checkpoint_path`` is missing.
        num_samples: Number of images to generate.

    Returns:
        None. When the checkpoint file is missing, it prints a diagnostic
        listing and returns without loading anything.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🔧 Using device: {device}")

    # isfile (not exists) so a directory passed by mistake is reported here
    # instead of failing later inside torch.load.
    if not os.path.isfile(checkpoint_path):
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        print("Available checkpoints:")
        candidates = []
        # isdir (not exists) so a regular file with this name does not make
        # os.listdir raise.
        if os.path.isdir(checkpoint_dir):
            candidates = sorted(
                name
                for name in os.listdir(checkpoint_dir)
                if name.endswith('.pth')
                and os.path.isfile(os.path.join(checkpoint_dir, name))
            )
        if not candidates:
            print(f"  (none found in {checkpoint_dir})")
        for name in candidates:
            print(f" - {os.path.join(checkpoint_dir, name)}")
        return

    model = load_model_from_checkpoint(checkpoint_path, device)
    samples = generate_samples(model, num_samples=num_samples, device=device)
    display_samples(samples)
# Script entry point: run the inference pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()