From 1169003ae647eb8b3a90cb6a0aab2be846a77d6e Mon Sep 17 00:00:00 2001 From: ydcjeff <ydcjeff@outlook.com> Date: Sat, 10 Apr 2021 11:49:12 +0630 Subject: [PATCH] fix: bump max_epochs to 5, get_handlers arguments --- templates/gan/_sidebar.py | 2 +- templates/gan/main.py | 3 +++ templates/image_classification/_sidebar.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/templates/gan/_sidebar.py b/templates/gan/_sidebar.py index d19bdebf..6b11a8ad 100644 --- a/templates/gan/_sidebar.py +++ b/templates/gan/_sidebar.py @@ -43,7 +43,7 @@ def optimizer_options(config): def training_options(config): st.markdown("## Training Options") - config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2) + config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=5) st.markdown("---") diff --git a/templates/gan/main.py b/templates/gan/main.py index 1b0fbb6f..05eb3b74 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -90,6 +90,9 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): config=config, model={'netD', netD, 'netG', netG}, train_engine=train_engine, + eval_engine=None, + metric_name=None, + es_metric_name=None, to_save=to_save, lr_scheduler=lr_scheduler, output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"], diff --git a/templates/image_classification/_sidebar.py b/templates/image_classification/_sidebar.py index 224f1073..e16aa8d1 100644 --- a/templates/image_classification/_sidebar.py +++ b/templates/image_classification/_sidebar.py @@ -51,7 +51,7 @@ def optimizer_options(config): def training_options(config): st.markdown("## Training Options") - config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2) + config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=5) config["num_warmup_epochs"] = st.number_input( "number of warm-up epochs before learning rate decay (num_warmup_epochs)", min_value=1, value=4 )