
Commit 3f87b20

Jeff Yang authored
feat(download): add an option to archive and download (#6)
1 parent 25fdd92 commit 3f87b20

File tree

4 files changed (+50 -60 lines)

app/codegen.py (+10 -12)
@@ -1,41 +1,39 @@
 """Code Generator base module.
 """
+import shutil
 from pathlib import Path
+
 from jinja2 import Environment, FileSystemLoader


 class CodeGenerator:
     def __init__(self, templates_dir=None):
         templates_dir = templates_dir or "./templates"
         self.template_list = [p.stem for p in Path(templates_dir).iterdir() if p.is_dir()]
-        self.env = Environment(
-            loader=FileSystemLoader(templates_dir), trim_blocks=True, lstrip_blocks=True
-        )
+        self.env = Environment(loader=FileSystemLoader(templates_dir), trim_blocks=True, lstrip_blocks=True)

     def render_templates(self, template_name: str, config: dict):
         """Renders all the templates from template folder for the given config.
         """
         file_template_list = (
-            template
-            for template in self.env.list_templates(".jinja")
-            if template.startswith(template_name)
+            template for template in self.env.list_templates(".jinja") if template.startswith(template_name)
         )
         for fname in file_template_list:
             # Get template
             template = self.env.get_template(fname)
             # Render template
             code = template.render(**config)
             # Write python file
-            fname = fname.strip(f"{template_name}/").strip(".jinja")
+            fname = fname.replace(f"{template_name}/", "").replace(".jinja", "")
             self.generate(template_name, fname, code)
             yield fname, code

     def generate(self, template_name: str, fname: str, code: str) -> None:
         """Generates `fname` with content `code` in `path`.
         """
-        path = Path(f"dist/{template_name}")
-        path.mkdir(parents=True, exist_ok=True)
-        (path / fname).write_text(code)
+        self.path = Path(f"./dist/{template_name}")
+        self.path.mkdir(parents=True, exist_ok=True)
+        (self.path / fname).write_text(code)

-    def make_archive(self):
-        raise NotImplementedError
+    def make_archive(self, format_):
+        return shutil.make_archive(base_name=str(self.path), format=format_, base_dir=self.path)
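
Note that `make_archive` reuses the `self.path` that `generate` records, so it only works after at least one template has been rendered, and `shutil.make_archive` appends the format's extension to that path. A minimal usage sketch (the template name and config key below are hypothetical, not part of this commit):

# Hypothetical usage of CodeGenerator after this change; assumes
# ./templates/base/ exists and contains *.jinja templates.
from codegen import CodeGenerator

codegen = CodeGenerator()
# render_templates() is a generator; consuming it writes files to ./dist/base
for fname, code in codegen.render_templates("base", {"device": "cpu"}):
    print("rendered", fname)

# Delegates to shutil.make_archive, producing e.g. ./dist/base.zip
print(codegen.make_archive("zip"))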

app/streamlit_app.py (+23 -1)
@@ -1,5 +1,7 @@
-import streamlit as st
+import shutil
+from pathlib import Path

+import streamlit as st
 from codegen import CodeGenerator
 from utils import import_from_file

@@ -60,9 +62,29 @@ def add_content(self):
         for fname, code in content:
             self.render_code(fname, code, fold)

+    def add_download(self):
+        st.markdown("")
+        format_ = st.radio(
+            "Archive format",
+            [name for name, _ in sorted(shutil.get_archive_formats(), key=lambda x: x[0], reverse=True)],
+        )
+        # temporary hack until streamlit has official download option
+        # https://github.com/streamlit/streamlit/issues/400
+        # https://github.com/streamlit/streamlit/issues/400#issuecomment-648580840
+        if st.button("Generate an archive"):
+            archive_fname = self.codegen.make_archive(format_)
+            # this is where streamlit serves static files
+            # ~/site-packages/streamlit/static/static/
+            dist_path = Path(st.__path__[0]) / "static/static/dist"
+            if not dist_path.is_dir():
+                dist_path.mkdir()
+            shutil.copy(archive_fname, dist_path)
+            st.success(f"Download link : [{archive_fname}](./static/{archive_fname})")
+
     def run(self):
         self.add_sidebar()
         self.add_content()
+        self.add_download()


 def main():
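
The download link works because Streamlit serves any file placed under its own package's static/static/ directory; the generated archive is copied there and linked with a relative URL. A standalone sketch of the same workaround (the archive name here is illustrative):

# Standalone sketch of the static-file download hack used above.
# Assumes ./dist/base.zip already exists; "base.zip" is a made-up name.
import shutil
from pathlib import Path

import streamlit as st

static_dist = Path(st.__path__[0]) / "static" / "static" / "dist"
static_dist.mkdir(parents=True, exist_ok=True)

archive = Path("./dist/base.zip")
shutil.copy(archive, static_dist)

# Reachable relative to the app URL, mirroring the st.success() link above
st.markdown(f"[Download {archive.name}](./static/dist/{archive.name})")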

templates/base/config.py (+16 -46)
@@ -4,74 +4,44 @@
 def get_configs() -> dict:
     config = {}
     with st.beta_expander("Training Configurations"):
-        st.info(
-            "Common base training configurations. Those in the parenthesis are used in the code."
-        )
+        st.info("Common base training configurations. Those in the parenthesis are used in the code.")

         # group by streamlit function type
-        config["amp_mode"] = st.selectbox(
-            "AMP mode (amp_mode)", ("None", "amp", "apex")
-        )
-        config["device"] = st.selectbox(
-            "Device to use (device)", ("cpu", "cuda", "xla")
-        )
+        config["amp_mode"] = st.selectbox("AMP mode (amp_mode)", ("None", "amp", "apex"))
+        config["device"] = st.selectbox("Device to use (device)", ("cpu", "cuda", "xla"))

         config["data_path"] = st.text_input("Dataset path (data_path)", value="./")
-        config["filepath"] = st.text_input(
-            "Logging file path (filepath)", value="./logs"
-        )
+        config["filepath"] = st.text_input("Logging file path (filepath)", value="./logs")

-        config["train_batch_size"] = st.number_input(
-            "Train batch size (train_batch_size)", min_value=1, value=1
-        )
-        config["eval_batch_size"] = st.number_input(
-            "Eval batch size (eval_batch_size)", min_value=1, value=1
-        )
-        config["num_workers"] = st.number_input(
-            "Number of workers (num_workers)", min_value=0, value=2
-        )
-        config["max_epochs"] = st.number_input(
-            "Maximum epochs to train (max_epochs)", min_value=1, value=2
-        )
+        config["train_batch_size"] = st.number_input("Train batch size (train_batch_size)", min_value=1, value=1)
+        config["eval_batch_size"] = st.number_input("Eval batch size (eval_batch_size)", min_value=1, value=1)
+        config["num_workers"] = st.number_input("Number of workers (num_workers)", min_value=0, value=2)
+        config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2)
         config["lr"] = st.number_input(
-            "Learning rate used by torch.optim.* (lr)",
-            min_value=0.0,
-            value=1e-3,
-            format="%e",
+            "Learning rate used by torch.optim.* (lr)", min_value=0.0, value=1e-3, format="%e",
         )
         config["log_train"] = st.number_input(
             "Logging interval of training iterations (log_train)", min_value=0, value=50
         )
-        config["log_eval"] = st.number_input(
-            "Logging interval of evaluation epoch (log_eval)", min_value=0, value=1
-        )
-        config["seed"] = st.number_input(
-            "Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666
-        )
+        config["log_eval"] = st.number_input("Logging interval of evaluation epoch (log_eval)", min_value=0, value=1)
+        config["seed"] = st.number_input("Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666)
         if st.checkbox("Use distributed training"):
             config["nproc_per_node"] = st.number_input(
-                "Number of processes to launch on each node (nproc_per_node)",
-                min_value=1,
-            )
-            config["nnodes"] = st.number_input(
-                "Number of nodes to use for distributed training (nnodes)", min_value=1,
+                "Number of processes to launch on each node (nproc_per_node)", min_value=1,
             )
+            config["nnodes"] = st.number_input("Number of nodes to use for distributed training (nnodes)", min_value=1,)
             if config["nnodes"] > 1:
                 st.info(
                     "The following options are only supported by torch.distributed, namely 'gloo' and 'nccl' backends."
                     " For other backends, please specify spawn_kwargs in main.py"
                 )
                 config["node_rank"] = st.number_input(
-                    "Rank of the node for multi-node distributed training (node_rank)",
-                    min_value=0,
+                    "Rank of the node for multi-node distributed training (node_rank)", min_value=0,
                 )
                 if config["node_rank"] > (config["nnodes"] - 1):
-                    st.error(
-                        f"node_rank should be between 0 and {config['nnodes'] - 1}"
-                    )
+                    st.error(f"node_rank should be between 0 and {config['nnodes'] - 1}")
                 config["master_addr"] = st.text_input(
-                    "Master node TCP/IP address for torch native backends (master_addr)",
-                    "'127.0.0.1'",
+                    "Master node TCP/IP address for torch native backends (master_addr)", "'127.0.0.1'",
                 )
                 st.warning("Please include single quote in master_addr.")
                 config["master_port"] = st.text_input(

templates/base/utils.py.jinja (+1 -1)
@@ -90,7 +90,7 @@ DEFAULTS = {
         "help": "rank of the node for multi-node distributed training ({{ node_rank }})",
     },
     "master_addr": {
-        "default": {{ master_addr|safe }},
+        "default": {{ master_addr }},
         "type": str,
         "help": "master node TCP/IP address for torch native backends ({{ master_addr }})",
     },
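
Dropping `|safe` here is behavior-preserving: the filter only matters when autoescaping is on, and the Environment built in app/codegen.py leaves autoescape at Jinja2's default of False. A standalone check (not part of the commit):

# With autoescape=False (Jinja2's default, and what codegen.py uses),
# {{ x }} and {{ x|safe }} render identically.
from jinja2 import Environment

env = Environment()
with_safe = env.from_string('"default": {{ master_addr|safe }},')
without = env.from_string('"default": {{ master_addr }},')

value = "'127.0.0.1'"  # the single-quoted value the config UI asks for
assert with_safe.render(master_addr=value) == without.render(master_addr=value)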

0 commit comments