Skip to content

Commit 9391807

Browse files
author
Jeff Yang
authored
feat(template): add a gan template (#22)
- Add DCGAN template - Add tests
1 parent e9b7de0 commit 9391807

11 files changed

+783
-1
lines changed

templates/gan/config.toml

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
# Sidebar widget definitions for the DCGAN template.
# Each top-level table appears to be `<option_name>.<widget_kind>`
# (selectbox / text_input / number_input look like Streamlit widgets —
# TODO confirm against the app that consumes this file); the keys inside
# each table are that widget's keyword arguments.

[dataset.selectbox]
label = 'Dataset to use (dataset)'
options = ["cifar10", "lsun", "imagenet", "folder", "lfw", "fake", "mnist"]

[data_path.text_input]
label = 'Dataset path (data_path)'
value = './'

[filepath.text_input]
label = 'Logging file path (filepath)'
value = './logs'

[saved_G.text_input]
label = 'Path to saved generator (saved_G)'
value = '.'

[saved_D.text_input]
label = 'Path to saved discriminator (saved_D)'
value = '.'

[batch_size.number_input]
label = 'Train batch size (batch_size)'
min_value = 0
value = 4

[num_workers.number_input]
label = 'Number of workers (num_workers)'
min_value = 0
value = 2

[max_epochs.number_input]
label = 'Maximum epochs to train (max_epochs)'
min_value = 1
value = 2

[lr.number_input]
label = 'Learning rate used by torch.optim.* (lr)'
min_value = 0.0
value = 1e-3
# Scientific-notation display format for the widget.
format = '%e'

[log_train.number_input]
label = 'Logging interval of training iterations (log_train)'
min_value = 0
value = 50

[seed.number_input]
label = 'Seed used in ignite.utils.manual_seed() (seed)'
min_value = 0
value = 666

# Distributed-training options below have no default `value`.

[nproc_per_node.number_input]
label = 'Number of processes to launch on each node (nproc_per_node)'
min_value = 1

[nnodes.number_input]
label = 'Number of nodes to use for distributed training (nnodes)'
min_value = 1

[node_rank.number_input]
label = 'Rank of the node for multi-node distributed training (node_rank)'
min_value = 0

[master_addr.text_input]
# NOTE(review): the default is a single-quoted string nested inside the TOML
# string ("'127.0.0.1'"), unlike the other text_inputs above — presumably so
# the generated code receives a Python string literal; confirm against the
# template that interpolates this value.
label = 'Master node TCP/IP address for torch native backends (master_addr)'
value = "'127.0.0.1'"

[master_port.number_input]
label = 'Master node port for torch native backends (master_port)'
value = 8080

[n_saved.number_input]
label = 'Number of best models to store (n_saved)'
min_value = 1
value = 2

[z_dim.number_input]
label = 'Size of the latent z vector (z_dim)'
value = 100

[alpha.number_input]
label = 'Running average decay factor (alpha)'
value = 0.98

[g_filters.number_input]
label = 'Number of filters in the second-to-last generator deconv layer (g_filters)'
value = 64

[d_filters.number_input]
label = 'Number of filters in first discriminator conv layer (d_filters)'
value = 64

[beta_1.number_input]
label = 'beta_1 for Adam optimizer (beta_1)'
value = 0.5

templates/gan/datasets.py.jinja

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from torchvision import transforms as T
2+
from torchvision import datasets as dset
3+
4+
5+
def get_datasets(dataset, dataroot):
    """Create the training dataset named by ``dataset``.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        dataset, num_channels

    Raises:
        RuntimeError: if ``dataset`` is not one of the supported names.
    """
    # All images are brought to 64x64; real datasets are also normalized
    # per channel to [-1, 1] via (x - 0.5) / 0.5.
    resize = T.Resize(64)
    crop = T.CenterCrop(64)
    to_tensor = T.ToTensor()
    normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(root=dataroot, transform=T.Compose([resize, crop, to_tensor, normalize]))
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=T.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(root=dataroot, download=True, transform=T.Compose([resize, to_tensor, normalize]))
        nc = 3

    elif dataset == "mnist":
        # MNIST images are grayscale (1 channel): the 3-channel Normalize
        # above cannot broadcast onto a (1, H, W) tensor, so use
        # single-channel statistics here (as in the pytorch/examples DCGAN).
        normalize_1ch = T.Normalize((0.5,), (0.5,))
        dataset = dset.MNIST(root=dataroot, download=True, transform=T.Compose([resize, to_tensor, normalize_1ch]))
        nc = 1

    elif dataset == "fake":
        # Synthetic images for smoke testing; tensors only, no normalization.
        dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError(f"Invalid dataset name: {dataset}")

    return dataset, nc

templates/gan/fn.py.jinja

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
3+
4+
def update(netD, netG, device, optimizerD, optimizerG, loss_fn, config, real_labels, fake_labels):
    """Create the combined discriminator/generator training-step closure.

    Args:
        netD, netG: discriminator / generator modules.
        device: device the real batch is moved to and noise is created on.
        optimizerD, optimizerG: optimizers stepping netD / netG respectively.
        loss_fn: adversarial loss (e.g. BCE) comparing netD output to labels.
        config: provides ``batch_size`` and ``z_dim`` for latent sampling.
        real_labels, fake_labels: target tensors for real / fake batches.

    Returns:
        ``step(engine, batch)`` — processes one batch and returns a dict of
        scalar training metrics (``engine`` itself is not used).
    """

    def step(engine, batch):
        # The dataset yields <image, label> pairs; the class label is unused.
        real_images, _ = batch
        real_images = real_images.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # Real pass.
        d_out_real = netD(real_images)
        errD_real = loss_fn(d_out_real, real_labels)
        D_x = d_out_real.mean().item()
        errD_real.backward()

        # Sample latent vectors and synthesize a fake batch.
        latent = torch.randn(config.batch_size, config.z_dim, 1, 1, device=device)
        fake_images = netG(latent)

        # Fake pass — detached so no gradient reaches the generator here.
        d_out_fake = netD(fake_images.detach())
        errD_fake = loss_fn(d_out_fake, fake_labels)
        D_G_z1 = d_out_fake.mean().item()
        errD_fake.backward()

        # Apply the accumulated real+fake gradients.
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Push the generator toward samples the (just-updated) discriminator
        # classifies as real.
        d_out_gen = netD(fake_images)
        errG = loss_fn(d_out_gen, real_labels)
        D_G_z2 = d_out_gen.mean().item()
        errG.backward()

        optimizerG.step()

        return {"errD": errD.item(), "errG": errG.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

    return step

0 commit comments

Comments
 (0)