potential_gan.py — 143 lines (125 loc) · 5.22 KB
import os
from datetime import datetime
from neon.callbacks.callbacks import Callbacks, GANCostCallback
#from neon.callbacks.plotting_callbacks import GANPlotCallback
from neon.initializers import Gaussian
from neon.layers import GeneralizedGANCost, Affine, Sequential, Conv, Deconv, Dropout, Pooling, BatchNorm
from neon.layers.layer import Linear, Reshape
from neon.layers.container import GenerativeAdversarial
from neon.models.model import GAN, Model
from neon.transforms import Rectlin, Logistic, GANCost, Tanh
from neon.util.argparser import NeonArgparser
from neon.util.persist import ensure_dirs_exist
from neon.layers.layer import Dropout
from neon.data.dataiterator import ArrayIterator
from neon.optimizers import GradientDescentMomentum, RMSProp
from gen_data_norm import gen_rhs
from neon.backends import gen_backend
from temporary_utils import temp_3Ddata
import numpy as np
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt
import h5py
# Load the 3-D volume dataset, clamp near-zero noise, split train/test,
# and wrap both splits in neon ArrayIterators for training/evaluation.
X, y = temp_3Ddata()
# Zero out values below 1e-6 (treat them as numerical noise).
X[X < 1e-6] = 0
# mean = np.mean(X, axis=0, keepdims=True)
print(np.max(X), 'max element')
print(np.min(X), 'min element')
# X -= mean
print(X.shape, 'X shape')
# 90/10 split with a fixed seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.9, random_state=42)
print(X_train.shape, 'X train shape')
print(y_train.shape, 'y train shape')
# GPU backend; batch_size here must match what the iterators feed per step.
gen_backend(backend='gpu', batch_size=100)
# lshape is (channels, dim, dim, dim) — presumably 1-channel 25^3 volumes;
# NOTE(review): confirm against temp_3Ddata's output layout.
train_set = ArrayIterator(X=X_train, y=y_train, nclass=2, lshape=(1, 25, 25, 25))
valid_set = ArrayIterator(X=X_test, y=y_test, nclass=2)
# Weight initialization shared by every discriminator layer.
init = Gaussian(scale=0.01)
# Discriminator: stacked 3-D convolutions with leaky-ReLU activations.
lrelu = Rectlin(slope=0.1)  # leaky ReLU; also reused by the generator below
# Shared keyword bundles for the three conv variants.
d_conv = {'init': init, 'batch_norm': False, 'activation': lrelu, 'bias': init}
d_conv_pad2 = {'init': init, 'batch_norm': False, 'activation': lrelu,
               'padding': 2, 'bias': init}
d_conv_sig = {'init': init, 'batch_norm': False, 'activation': Logistic(),
              'padding': 1, 'bias': init}
D_layers = [
    Conv((5, 5, 5, 32), **d_conv),
    Dropout(keep=0.8),
    Conv((5, 5, 5, 8), **d_conv_pad2),
    BatchNorm(),
    Dropout(keep=0.8),
    Conv((5, 5, 5, 8), **d_conv_pad2),
    BatchNorm(),
    Dropout(keep=0.8),
    Conv((5, 5, 5, 8), **d_conv_sig),
    BatchNorm(),
    Dropout(keep=0.8),
    Pooling((2, 2, 2)),
    Affine(512, init=init),
    Affine(512, init=init, bias=init),
    # Single sigmoid output: probability that the input volume is real.
    Affine(1, init=init, bias=init, activation=Logistic()),
]
# Generator: project the latent vector to an 8x7x7x7 volume, upsample
# twice with strided deconvolutions, and finish with a sigmoid conv.
init_gen = Gaussian(scale=0.01)
relu = Rectlin(slope=0)  # plain ReLU, defined for the generator (not used below)
# Upsampling deconvs: pad 2 / stride 2 in all three spatial dimensions.
g_pad_up1 = dict(pad_h=2, pad_w=2, pad_d=2)
g_str_up1 = dict(str_h=2, str_w=2, str_d=2)
g_deconv1 = dict(init=init_gen, batch_norm=False, activation=lrelu,
                 padding=g_pad_up1, strides=g_str_up1, bias=init_gen)
g_pad_up2 = dict(pad_h=2, pad_w=2, pad_d=2)
g_str_up2 = dict(str_h=2, str_w=2, str_d=2)
g_deconv2 = dict(init=init_gen, batch_norm=False, activation=lrelu,
                 padding=g_pad_up2, strides=g_str_up2, bias=init_gen)
# Output conv: no padding, unit stride, sigmoid to bound voxel values.
g_pad_out = dict(pad_h=0, pad_w=0, pad_d=0)
g_str_out = dict(str_h=1, str_w=1, str_d=1)
g_conv_out = dict(init=init_gen, batch_norm=False, activation=Logistic(),
                  padding=g_pad_out, strides=g_str_out, bias=init_gen)
G_layers = [
    Affine(8 * 7 * 7 * 7, init=init_gen, bias=init_gen),
    Reshape((8, 7, 7, 7)),
    Deconv((6, 6, 6, 6), **g_deconv1),   # 14x14x14
    BatchNorm(),
    Deconv((5, 5, 5, 64), **g_deconv2),  # 27x27x27
    BatchNorm(),
    Conv((3, 3, 3, 1), **g_conv_out),
]
# Bundle generator and discriminator into one adversarial container.
layers = GenerativeAdversarial(generator=Sequential(G_layers, name="Generator"),
                               discriminator=Sequential(D_layers, name="Discriminator"))
# setup optimizer
# optimizer = RMSProp(learning_rate=1e-5, decay_rate=0.9, epsilon=1e-8)
optimizer = GradientDescentMomentum(learning_rate=1e-3, momentum_coef=0.9)
# optimizer = Adam(learning_rate=0.01)
# setup cost function as Binary CrossEntropy
cost = GeneralizedGANCost(costfunc=GANCost(func="original"))
nb_epochs = 5
latent_size = 200
nb_classes = 2  # fixed typo: was 'inb_classes' (never read anywhere, safe rename)
nb_test = 100
# initialize model
# FIX: '(latent_size)' is just the int 200 — a 1-element tuple needs a
# trailing comma to describe the 1-D latent shape neon expects.
noise_dim = (latent_size,)
gan = GAN(layers=layers, noise_dim=noise_dim)
# configure callbacks
callbacks = Callbacks(gan, eval_set=valid_set)
callbacks.add_callback(GANCostCallback())
#callbacks.add_save_best_state_callback("./best_state.pkl")
# run fit
gan.fit(train_set, num_epochs=nb_epochs, optimizer=optimizer,
        cost=cost, callbacks=callbacks)
# gan.save_params('our_gan.prm')
# Sample 100 random latent vectors and push them through the trained generator.
x_new = np.random.rand(100, latent_size)
# FIX: 'lshape=(latent_size)' was just the int 200; a 1-tuple needs the comma.
inference_set = ArrayIterator(x_new, None, nclass=2, lshape=(latent_size,))
# Persist the trained sub-networks separately for later reuse.
my_generator = Model(gan.layers.generator)
my_generator.save_params('our_gen.prm')
my_discriminator = Model(gan.layers.discriminator)
my_discriminator.save_params('our_disc.prm')
test = my_generator.get_outputs(inference_set)
# test += mean
# Reshape flat generator outputs back into 25^3 volumes.
test = test.reshape((100, 25, 25, 25))
print(test.shape, 'generator output')
#plt.plot(test[0, :, 12, :])
# plt.savefigure('output_img.png')
# FIX: the original opened the HDF5 file and never closed it; the context
# manager guarantees the dataset is flushed and the handle released.
with h5py.File('output_data.h5', 'w') as h5f:
    h5f.create_dataset('dataset_1', data=test)