
Commit f0dd7c4: usability updates
1 parent: fb99c19

1 file changed: +26 -14 lines

pix2pix.py (file mode changed 100644 → 100755)
26 additions, 14 deletions
@@ -14,30 +14,31 @@
 import time

 parser = argparse.ArgumentParser()
-parser.add_argument("--input_dir", help="path to folder containing images")
+parser.add_argument("--input_dir", default='../../../data/GoogleArt_wikimedia', help="path to folder containing datasets")
+parser.add_argument("--dataset", default='gart_256_p2p_ds2_crop2h', help="name of folder containing images in input_dir")
+parser.add_argument("--output_dir", default='./out', help="where to put output files (dataset name will be appended)")
 parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
-parser.add_argument("--output_dir", required=True, help="where to put output files")
 parser.add_argument("--seed", type=int)
-parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")
+parser.add_argument("--checkpoint", help="directory with checkpoint to resume training from or use for testing")

 parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
-parser.add_argument("--max_epochs", type=int, help="number of training epochs")
+parser.add_argument("--max_epochs", default=200, type=int, help="number of training epochs")
 parser.add_argument("--summary_freq", type=int, default=100, help="update summaries every summary_freq steps")
 parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
 parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
-parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps")
+parser.add_argument("--display_freq", type=int, default=5000, help="write current training images every display_freq steps")
 parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")

 parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
 parser.add_argument("--lab_colorization", action="store_true", help="split input image into brightness (A) and color (B)")
-parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
-parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
+parser.add_argument("--batch_size", type=int, default=4, help="number of images in batch")
+parser.add_argument("--which_direction", type=str, default="BtoA", choices=["AtoB", "BtoA"])
 parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
 parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
-parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256")
+parser.add_argument("--scale_size", type=int, default=256, help="scale images to this size before cropping to 256x256")
 parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
 parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
-parser.set_defaults(flip=True)
+parser.set_defaults(flip=False)
 parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
 parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
 parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
@@ -46,9 +47,17 @@
 # export options
 parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
 a = parser.parse_args()
+a.input_dir = os.path.join(a.input_dir, a.dataset)
+a.output_dir = os.path.join(a.output_dir, a.dataset)
+
+if a.checkpoint is not None and len(a.checkpoint) > 0:
+    a.checkpoint = a.output_dir
+
+if a.checkpoint is None and a.mode != "train":
+    a.checkpoint = a.output_dir

 EPS = 1e-12
-CROP_SIZE = 256
+CROP_SIZE = a.scale_size

 Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
 Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
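
Note: the lines added after parse_args() join the dataset name onto both directories and redirect the checkpoint to the per-dataset output directory. A standalone sketch of that logic (not part of the commit; the namespace values are examples):

# Hedged sketch of the path/checkpoint resolution added above, with example values.
import os
from types import SimpleNamespace

a = SimpleNamespace(input_dir='../../../data/GoogleArt_wikimedia',
                    output_dir='./out',
                    dataset='gart_256_p2p_ds2_crop2h',
                    checkpoint=None,
                    mode='test')

a.input_dir = os.path.join(a.input_dir, a.dataset)
a.output_dir = os.path.join(a.output_dir, a.dataset)

# A non-empty --checkpoint is redirected, and a missing one is filled in
# outside of training, so checkpoints always resolve to the output dir.
if a.checkpoint is not None and len(a.checkpoint) > 0:
    a.checkpoint = a.output_dir
if a.checkpoint is None and a.mode != "train":
    a.checkpoint = a.output_dir

print(a.checkpoint)  # ./out/gart_256_p2p_ds2_crop2h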
@@ -502,12 +511,14 @@ def save_images(fetches, step=None):

     filesets = []
     for i, in_path in enumerate(fetches["paths"]):
-        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
+        # name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
+        name = str(i)
         fileset = {"name": name, "step": step}
         for kind in ["inputs", "outputs", "targets"]:
-            filename = name + "-" + kind + ".png"
+            filename = name + "-" + kind
             if step is not None:
                 filename = "%08d-%s" % (step, filename)
+            filename = filename[:130] + ".png" # prevent errors with too long filename!
             fileset[kind] = filename
             out_path = os.path.join(image_dir, filename)
             contents = fetches[kind][i]
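
Note: with the index-based name and the 130-character cap, output filenames stay short regardless of how long the input paths are. A sketch of just the filename construction above (not part of the commit; step and kind are example values):

# Hedged sketch: filename construction as changed above, with example values.
step = 5000
name = str(0)                        # index-based name replaces the path-derived one
kind = "outputs"

filename = name + "-" + kind
if step is not None:
    filename = "%08d-%s" % (step, filename)
filename = filename[:130] + ".png"   # cap the length before adding the extension
print(filename)                      # 00005000-0-outputs.png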
@@ -558,7 +569,8 @@ def main():

     if a.mode == "test" or a.mode == "export":
         if a.checkpoint is None:
-            raise Exception("checkpoint required for test mode")
+            a.checkpoint = a.output_dir
+            #raise Exception("checkpoint required for test mode")

         # load some options from the checkpoint
         options = {"which_direction", "ngf", "ndf", "lab_colorization"}
@@ -717,7 +729,7 @@ def convert(image):
     with tf.name_scope("parameter_count"):
         parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

-    saver = tf.train.Saver(max_to_keep=1)
+    saver = tf.train.Saver(max_to_keep=5)

     logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
     sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
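
Note: max_to_keep controls how many recent checkpoints tf.train.Saver retains before deleting older ones, so this change keeps the last five instead of only the latest. A minimal sketch under the TF1-style API the script already uses (the variable is a hypothetical stand-in so the Saver has something to track):

# Hedged sketch: checkpoint retention with the new setting (TF1-style API).
import tensorflow as tf

w = tf.Variable(0.0, name="w")          # hypothetical variable to checkpoint
saver = tf.train.Saver(max_to_keep=5)   # keep the five most recent checkpoints
# Checkpoints older than the newest five are deleted automatically as training saves.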
