feat(download): add an option to archive and download #6

Merged · 5 commits · Mar 11, 2021
app/codegen.py (22 changes: 10 additions & 12 deletions)
@@ -1,41 +1,39 @@
 """Code Generator base module.
 """
+import shutil
 from pathlib import Path
 
 from jinja2 import Environment, FileSystemLoader
 
 
 class CodeGenerator:
     def __init__(self, templates_dir=None):
         templates_dir = templates_dir or "./templates"
         self.template_list = [p.stem for p in Path(templates_dir).iterdir() if p.is_dir()]
-        self.env = Environment(
-            loader=FileSystemLoader(templates_dir), trim_blocks=True, lstrip_blocks=True
-        )
+        self.env = Environment(loader=FileSystemLoader(templates_dir), trim_blocks=True, lstrip_blocks=True)
 
     def render_templates(self, template_name: str, config: dict):
         """Renders all the templates from template folder for the given config.
         """
         file_template_list = (
-            template
-            for template in self.env.list_templates(".jinja")
-            if template.startswith(template_name)
+            template for template in self.env.list_templates(".jinja") if template.startswith(template_name)
         )
         for fname in file_template_list:
             # Get template
             template = self.env.get_template(fname)
             # Render template
             code = template.render(**config)
             # Write python file
-            fname = fname.strip(f"{template_name}/").strip(".jinja")
+            fname = fname.replace(f"{template_name}/", "").replace(".jinja", "")
             self.generate(template_name, fname, code)
             yield fname, code
 
     def generate(self, template_name: str, fname: str, code: str) -> None:
         """Generates `fname` with content `code` in `path`.
         """
-        path = Path(f"dist/{template_name}")
-        path.mkdir(parents=True, exist_ok=True)
-        (path / fname).write_text(code)
+        self.path = Path(f"./dist/{template_name}")
+        self.path.mkdir(parents=True, exist_ok=True)
+        (self.path / fname).write_text(code)
 
-    def make_archive(self):
-        raise NotImplementedError
+    def make_archive(self, format_):
+        return shutil.make_archive(base_name=str(self.path), format=format_, base_dir=self.path)
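
Two behavioral changes ride along with the reformatting here. First, `str.strip` treats its argument as a set of characters to trim from both ends, not as a literal prefix or suffix, so the old `fname.strip(...)` could eat parts of a file name; `replace` fixes that. Second, `make_archive` now delegates to `shutil.make_archive`. A minimal sketch of both, using a hypothetical `dist/demo` tree rather than the real templates:

```python
import shutil
from pathlib import Path

# The strip() pitfall: any character in the argument is stripped, so a
# template file whose name starts with a letter from the set gets mangled.
print("single/init.py.jinja".strip("single/"))  # -> 't.py.jinja' (wrong)
print("single/init.py.jinja".replace("single/", "").replace(".jinja", ""))  # -> 'init.py'

# Sketch of the new make_archive behaviour on a hypothetical output tree.
path = Path("dist/demo")
path.mkdir(parents=True, exist_ok=True)
(path / "main.py").write_text("print('ok')\n")
archive = shutil.make_archive(base_name=str(path), format="zip", base_dir=path)
print(archive)  # .../dist/demo.zip, ready to be copied somewhere and served
```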
app/streamlit_app.py (24 changes: 23 additions & 1 deletion)
@@ -1,5 +1,7 @@
-import streamlit as st
+import shutil
+from pathlib import Path
+
+import streamlit as st
 from codegen import CodeGenerator
 from utils import import_from_file
 
@@ -60,9 +62,29 @@ def add_content(self):
         for fname, code in content:
             self.render_code(fname, code, fold)
 
+    def add_download(self):
+        st.markdown("")
+        format_ = st.radio(
+            "Archive format",
+            [name for name, _ in sorted(shutil.get_archive_formats(), key=lambda x: x[0], reverse=True)],
+        )
+        # temporary hack until streamlit has official download option
+        # https://github.com/streamlit/streamlit/issues/400
+        # https://github.com/streamlit/streamlit/issues/400#issuecomment-648580840
+        if st.button("Generate an archive"):
+            archive_fname = self.codegen.make_archive(format_)
+            # this is where streamlit serves static files
+            # ~/site-packages/streamlit/static/static/
+            dist_path = Path(st.__path__[0]) / "static/static/dist"
+            if not dist_path.is_dir():
+                dist_path.mkdir()
+            shutil.copy(archive_fname, dist_path)
+            st.success(f"Download link : [{archive_fname}](./static/{archive_fname})")
+
     def run(self):
         self.add_sidebar()
         self.add_content()
+        self.add_download()
 
 
 def main():
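
The radio options come straight from the standard library: `shutil.get_archive_formats()` returns `(name, description)` pairs, and reverse-sorting by name puts `zip` first, so it becomes the default selection. A quick sketch of what feeds the widget (the exact set depends on which compression modules the Python build ships, e.g. zlib, bz2, lzma):

```python
import shutil

# Typically prints zip, xztar, tar, gztar, bztar in that order.
for name, description in sorted(shutil.get_archive_formats(), key=lambda x: x[0], reverse=True):
    print(f"{name:5s} {description}")
```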
templates/base/config.py (62 changes: 16 additions & 46 deletions)
@@ -4,74 +4,44 @@
 def get_configs() -> dict:
     config = {}
     with st.beta_expander("Training Configurations"):
-        st.info(
-            "Common base training configurations. Those in the parenthesis are used in the code."
-        )
+        st.info("Common base training configurations. Those in the parenthesis are used in the code.")
 
         # group by streamlit function type
-        config["amp_mode"] = st.selectbox(
-            "AMP mode (amp_mode)", ("None", "amp", "apex")
-        )
-        config["device"] = st.selectbox(
-            "Device to use (device)", ("cpu", "cuda", "xla")
-        )
+        config["amp_mode"] = st.selectbox("AMP mode (amp_mode)", ("None", "amp", "apex"))
+        config["device"] = st.selectbox("Device to use (device)", ("cpu", "cuda", "xla"))
 
         config["data_path"] = st.text_input("Dataset path (data_path)", value="./")
-        config["filepath"] = st.text_input(
-            "Logging file path (filepath)", value="./logs"
-        )
+        config["filepath"] = st.text_input("Logging file path (filepath)", value="./logs")
 
-        config["train_batch_size"] = st.number_input(
-            "Train batch size (train_batch_size)", min_value=1, value=1
-        )
-        config["eval_batch_size"] = st.number_input(
-            "Eval batch size (eval_batch_size)", min_value=1, value=1
-        )
-        config["num_workers"] = st.number_input(
-            "Number of workers (num_workers)", min_value=0, value=2
-        )
-        config["max_epochs"] = st.number_input(
-            "Maximum epochs to train (max_epochs)", min_value=1, value=2
-        )
+        config["train_batch_size"] = st.number_input("Train batch size (train_batch_size)", min_value=1, value=1)
+        config["eval_batch_size"] = st.number_input("Eval batch size (eval_batch_size)", min_value=1, value=1)
+        config["num_workers"] = st.number_input("Number of workers (num_workers)", min_value=0, value=2)
+        config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2)
         config["lr"] = st.number_input(
-            "Learning rate used by torch.optim.* (lr)",
-            min_value=0.0,
-            value=1e-3,
-            format="%e",
+            "Learning rate used by torch.optim.* (lr)", min_value=0.0, value=1e-3, format="%e",
         )
         config["log_train"] = st.number_input(
             "Logging interval of training iterations (log_train)", min_value=0, value=50
         )
-        config["log_eval"] = st.number_input(
-            "Logging interval of evaluation epoch (log_eval)", min_value=0, value=1
-        )
-        config["seed"] = st.number_input(
-            "Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666
-        )
+        config["log_eval"] = st.number_input("Logging interval of evaluation epoch (log_eval)", min_value=0, value=1)
+        config["seed"] = st.number_input("Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666)
         if st.checkbox("Use distributed training"):
             config["nproc_per_node"] = st.number_input(
-                "Number of processes to launch on each node (nproc_per_node)",
-                min_value=1,
+                "Number of processes to launch on each node (nproc_per_node)", min_value=1,
             )
-            config["nnodes"] = st.number_input(
-                "Number of nodes to use for distributed training (nnodes)", min_value=1,
-            )
+            config["nnodes"] = st.number_input("Number of nodes to use for distributed training (nnodes)", min_value=1,)
             if config["nnodes"] > 1:
                 st.info(
                     "The following options are only supported by torch.distributed, namely 'gloo' and 'nccl' backends."
                     " For other backends, please specify spawn_kwargs in main.py"
                 )
                 config["node_rank"] = st.number_input(
-                    "Rank of the node for multi-node distributed training (node_rank)",
-                    min_value=0,
+                    "Rank of the node for multi-node distributed training (node_rank)", min_value=0,
                 )
                 if config["node_rank"] > (config["nnodes"] - 1):
-                    st.error(
-                        f"node_rank should be between 0 and {config['nnodes'] - 1}"
-                    )
+                    st.error(f"node_rank should be between 0 and {config['nnodes'] - 1}")
                 config["master_addr"] = st.text_input(
-                    "Master node TCP/IP address for torch native backends (master_addr)",
-                    "'127.0.0.1'",
+                    "Master node TCP/IP address for torch native backends (master_addr)", "'127.0.0.1'",
                 )
                 st.warning("Please include single quote in master_addr.")
                 config["master_port"] = st.text_input(
templates/base/utils.py.jinja (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ DEFAULTS = {
         "help": "rank of the node for multi-node distributed training ({{ node_rank }})",
     },
     "master_addr": {
-        "default": {{ master_addr|safe }},
+        "default": {{ master_addr }},
         "type": str,
         "help": "master node TCP/IP address for torch native backends ({{ master_addr }})",
     },
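
Dropping `|safe` should be a no-op here: the generator's `Environment` in app/codegen.py is created without `autoescape`, and Jinja2 defaults to `autoescape=False`, where the `safe` filter has no effect. A minimal sketch to confirm:

```python
from jinja2 import Environment

env = Environment()  # autoescape defaults to False
template = env.from_string("{{ master_addr }} vs {{ master_addr|safe }}")
# With autoescaping off, both placeholders render identically.
print(template.render(master_addr="'127.0.0.1'"))  # '127.0.0.1' vs '127.0.0.1'
```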