14 | 14 | import time
15 | 15 |
16 | 16 | parser = argparse.ArgumentParser()
17 |    | -parser.add_argument("--input_dir", help="path to folder containing images")
   | 17 | +parser.add_argument("--input_dir", default='../../../data/GoogleArt_wikimedia', help="path to folder containing datasets")
   | 18 | +parser.add_argument("--dataset", default='gart_256_p2p_ds2_crop2h', help="name of folder containing images in input_dir")
   | 19 | +parser.add_argument("--output_dir", default='./out', help="where to put output files (dataset name will be appended)")
18 | 20 | parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
19 |    | -parser.add_argument("--output_dir", required=True, help="where to put output files")
20 | 21 | parser.add_argument("--seed", type=int)
21 |    | -parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")
   | 22 | +parser.add_argument("--checkpoint", help="directory with checkpoint to resume training from or use for testing")
22 | 23 |
23 | 24 | parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
24 |    | -parser.add_argument("--max_epochs", type=int, help="number of training epochs")
   | 25 | +parser.add_argument("--max_epochs", default=200, type=int, help="number of training epochs")
25 | 26 | parser.add_argument("--summary_freq", type=int, default=100, help="update summaries every summary_freq steps")
26 | 27 | parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
27 | 28 | parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
28 |    | -parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps")
   | 29 | +parser.add_argument("--display_freq", type=int, default=5000, help="write current training images every display_freq steps")
29 | 30 | parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
30 | 31 |
31 | 32 | parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
32 | 33 | parser.add_argument("--lab_colorization", action="store_true", help="split input image into brightness (A) and color (B)")
33 |    | -parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
34 |    | -parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
   | 34 | +parser.add_argument("--batch_size", type=int, default=4, help="number of images in batch")
   | 35 | +parser.add_argument("--which_direction", type=str, default="BtoA", choices=["AtoB", "BtoA"])
35 | 36 | parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
36 | 37 | parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
37 |    | -parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256")
   | 38 | +parser.add_argument("--scale_size", type=int, default=256, help="scale images to this size before cropping (CROP_SIZE now follows scale_size)")
38 | 39 | parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
39 | 40 | parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
40 |    | -parser.set_defaults(flip=True)
   | 41 | +parser.set_defaults(flip=False)
41 | 42 | parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
42 | 43 | parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
43 | 44 | parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
44 | 45 | parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")
45 | 46 |
46 | 47 | # export options
47 | 48 | parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
48 | 49 | a = parser.parse_args()
   | 50 | +a.input_dir = os.path.join(a.input_dir, a.dataset)
   | 51 | +a.output_dir = os.path.join(a.output_dir, a.dataset)
   | 52 | +
   | 53 | +if a.checkpoint is not None and len(a.checkpoint) > 0:
   | 54 | +    a.checkpoint = a.output_dir
   | 55 | +
   | 56 | +if a.checkpoint is None and a.mode != "train":
   | 57 | +    a.checkpoint = a.output_dir
49 | 58 |
50 | 59 | EPS = 1e-12
51 |    | -CROP_SIZE = 256
   | 60 | +CROP_SIZE = a.scale_size
52 | 61 |
53 | 62 | Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
54 | 63 | Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
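
With these defaults the script derives its working paths at startup instead of requiring them on the command line. A minimal sketch of what the joins produce (the values are simply the commit's defaults):

    import os

    input_dir = os.path.join('../../../data/GoogleArt_wikimedia', 'gart_256_p2p_ds2_crop2h')
    # -> '../../../data/GoogleArt_wikimedia/gart_256_p2p_ds2_crop2h'
    output_dir = os.path.join('./out', 'gart_256_p2p_ds2_crop2h')
    # -> './out/gart_256_p2p_ds2_crop2h'

Note that both checkpoint branches assign a.output_dir: a non-empty --checkpoint is redirected there, and so is a missing one outside train mode, so only a fresh training run keeps a.checkpoint as None.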
@@ -502,12 +511,14 @@ def save_images(fetches, step=None):
502 | 511 |
503 | 512 |     filesets = []
504 | 513 |     for i, in_path in enumerate(fetches["paths"]):
505 |     | -        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
    | 514 | +        # name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
    | 515 | +        name = str(i)
506 | 516 |         fileset = {"name": name, "step": step}
507 | 517 |         for kind in ["inputs", "outputs", "targets"]:
508 |     | -            filename = name + "-" + kind + ".png"
    | 518 | +            filename = name + "-" + kind
509 | 519 |             if step is not None:
510 | 520 |                 filename = "%08d-%s" % (step, filename)
    | 521 | +            filename = filename[:130] + ".png"  # prevent errors from overly long filenames
511 | 522 |             fileset[kind] = filename
512 | 523 |             out_path = os.path.join(image_dir, filename)
513 | 524 |             contents = fetches[kind][i]
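
Appending ".png" only after truncation lets the 130-character cap cover the whole base name; most filesystems limit a single filename to 255 bytes, so this leaves ample headroom. With name = str(i) the names stay short anyway, so the guard mainly matters if the commented-out basename-derived name is ever restored. A small sketch (step and kind values made up for illustration):

    name = str(7)                            # 8th image in the fetched batch
    filename = name + "-" + "outputs"        # '7-outputs'
    filename = "%08d-%s" % (5000, filename)  # '00005000-7-outputs'
    filename = filename[:130] + ".png"       # '00005000-7-outputs.png'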
@@ -558,7 +569,8 @@ def main():
558 | 569 |
559 | 570 |     if a.mode == "test" or a.mode == "export":
560 | 571 |         if a.checkpoint is None:
561 |     | -            raise Exception("checkpoint required for test mode")
    | 572 | +            a.checkpoint = a.output_dir
    | 573 | +            # raise Exception("checkpoint required for test mode")
562 | 574 |
563 | 575 |         # load some options from the checkpoint
564 | 576 |         options = {"which_direction", "ngf", "ndf", "lab_colorization"}
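
Combined with the top-level fallbacks after parse_args(), this makes --checkpoint effectively optional in every mode. A sketch of the net behavior (resolve_checkpoint is a hypothetical helper, not a function in the script):

    def resolve_checkpoint(checkpoint, mode, output_dir):
        # Only a fresh training run proceeds without a checkpoint; every
        # other combination ends up pointing at output_dir.
        if checkpoint is None and mode == "train":
            return None
        return output_dir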
@@ -717,7 +729,7 @@ def convert(image):
717 | 729 |     with tf.name_scope("parameter_count"):
718 | 730 |         parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
719 | 731 |
720 |     | -    saver = tf.train.Saver(max_to_keep=1)
    | 732 | +    saver = tf.train.Saver(max_to_keep=5)
721 | 733 |
722 | 734 |     logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
723 | 735 |     sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
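
Raising max_to_keep from 1 to 5 makes tf.train.Saver retain the five most recent checkpoints (TensorFlow's own default) instead of only the latest, so earlier training states stay recoverable at the cost of some disk space.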
|