Skip to content

Commit 07b1115

Browse files
author
Jeff Yang
authored
feat: allow users give project name (#35)
* feat: allow users give project name * iterdir -> rglob
1 parent ca80e3d commit 07b1115

14 files changed

+27
-26
lines changed

app/codegen.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@ class CodeGenerator:
1010
def __init__(self, templates_dir: str = "./templates", dist_dir: str = "./dist"):
1111
self.templates_dir = Path(templates_dir)
1212
self.dist_dir = Path(dist_dir)
13-
self.template_list = [p.stem for p in self.templates_dir.iterdir() if p.is_dir() and not p.stem.startswith("_")]
14-
self.rendered_code = {t: {} for t in self.template_list}
13+
self.rendered_code = {}
1514
self.available_archive_formats = [x[0] for x in shutil.get_archive_formats()[::-1]]
1615

17-
def render_templates(self, template_name: str, config: dict):
16+
def render_templates(self, template_name: str, project_name: str, config: dict):
1817
"""Renders all the templates files from template folder for the given config."""
19-
self.rendered_code[template_name] = {} # clean up the rendered code for given template
2018
# loading the template files based on given template and from the _base folder
2119
# since we are using some templates from _base folder
2220
loader = FileSystemLoader([self.templates_dir / "_base", self.templates_dir / template_name])
@@ -27,11 +25,11 @@ def render_templates(self, template_name: str, config: dict):
2725
)
2826
for fname in env.list_templates(filter_func=lambda x: not x.startswith("_")):
2927
code = env.get_template(fname).render(**config)
30-
fname = fname.replace(".pyi", ".py")
31-
self.rendered_code[template_name][fname] = code
28+
fname = fname.replace(".pyi", ".py").replace(template_name, project_name)
29+
self.rendered_code[fname] = code
3230
yield fname, code
3331

34-
def make_and_write(self, template_name: str):
32+
def make_and_write(self, template_name: str, project_name: str):
3533
"""Make the directories first and write to the files"""
3634
for p in (self.templates_dir / template_name).rglob("*"):
3735
if not p.stem.startswith("_") and p.is_dir():
@@ -42,20 +40,21 @@ def make_and_write(self, template_name: str):
4240
else:
4341
p = template_name
4442

43+
p = p.replace(template_name, project_name)
4544
if not (self.dist_dir / p).is_dir():
4645
(self.dist_dir / p).mkdir(parents=True, exist_ok=True)
4746

48-
for fname, code in self.rendered_code[template_name].items():
49-
(self.dist_dir / template_name / fname).write_text(code)
47+
for fname, code in self.rendered_code.items():
48+
(self.dist_dir / project_name / fname).write_text(code)
5049

51-
def make_archive(self, template_name, archive_format):
50+
def make_archive(self, template_name, project_name: str, archive_format):
5251
"""Creates dist dir with generated code, then makes the archive."""
5352

54-
self.make_and_write(template_name)
53+
self.make_and_write(template_name, project_name)
5554
archive_fname = shutil.make_archive(
56-
base_name=template_name,
55+
base_name=project_name,
5756
root_dir=self.dist_dir,
5857
format=archive_format,
59-
base_dir=template_name,
58+
base_dir=project_name,
6059
)
6160
return shutil.move(archive_fname, self.dist_dir / archive_fname.split("/")[-1])

app/streamlit_app.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,13 @@ def sidebar(self, template_list=None, config=None):
5353
template_list = template_list or []
5454
st.markdown("### Choose a Template")
5555
self.template_name = st.selectbox("Available Templates are:", options=template_list)
56+
self.project_name = st.text_input("Project Name:", "project_1")
5657
self.template_name = FOLDER_TO_TEMPLATE_NAME[self.template_name]
5758
with st.sidebar:
5859
if self.template_name:
5960
config = config(self.template_name)
6061
self.config = config.get_configs()
62+
self.config["project_name"] = self.project_name
6163
else:
6264
self.config = {}
6365

@@ -120,7 +122,7 @@ def config(template_name):
120122

121123
def add_content(self):
122124
"""Get generated/rendered code from the codegen."""
123-
content = [*self.codegen.render_templates(self.template_name, self.config)]
125+
content = [*self.codegen.render_templates(self.template_name, self.project_name, self.config)]
124126
if st.checkbox("View rendered code ?", value=True):
125127
for fname, code in content:
126128
if len(code): # don't show files which don't have content in them
@@ -135,7 +137,7 @@ def add_download(self):
135137
# https://github.com/streamlit/streamlit/issues/400
136138
# https://github.com/streamlit/streamlit/issues/400#issuecomment-648580840
137139
if st.button("Generate an archive"):
138-
archive_fname = self.codegen.make_archive(self.template_name, archive_format)
140+
archive_fname = self.codegen.make_archive(self.template_name, self.project_name, archive_format)
139141
# this is where streamlit serves static files
140142
# ~/site-packages/streamlit/static/static/
141143
dist_path = Path(st.__path__[0]) / "static/static/dist"
@@ -144,7 +146,7 @@ def add_download(self):
144146
shutil.copy(archive_fname, dist_path)
145147
st.success(f"Download link : [{archive_fname}](./static/{archive_fname})")
146148
with col2:
147-
self.render_directory(Path(self.codegen.dist_dir, self.template_name))
149+
self.render_directory(Path(self.codegen.dist_dir, self.project_name))
148150

149151
def run(self):
150152
self.add_sidebar()

templates/single/setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def dependencies(fname):
2424

2525
setup(
2626
# Metadata
27-
name="single_cg",
27+
name="{{project_name}}",
2828
version=VERSION,
2929
long_description_content_type="text/markdown",
3030
long_description=readme,

templates/single/single_cg/engines.pyi renamed to templates/single/single/engines.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from ignite.engine import Engine
88
from torch.cuda.amp import autocast
99
from torch.optim.optimizer import Optimizer
1010

11-
from single_cg.events import TrainEvents, train_events_to_attr
11+
from {{project_name}}.events import TrainEvents, train_events_to_attr
1212

1313

1414
# Edit below functions the way how the model will be training

templates/single/single_cg/main.pyi renamed to templates/single/single/main.pyi

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ import ignite.distributed as idist
1010
from ignite.engine.events import Events
1111
from ignite.utils import manual_seed
1212

13-
from single_cg.engines import create_engines
14-
from single_cg.events import TrainEvents
15-
from single_cg.handlers import get_handlers, get_logger
16-
from single_cg.utils import get_default_parser, setup_logging, log_metrics, log_basic_info, initialize, resume_from
13+
from {{project_name}}.engines import create_engines
14+
from {{project_name}}.events import TrainEvents
15+
from {{project_name}}.handlers import get_handlers, get_logger
16+
from {{project_name}}.utils import get_default_parser, setup_logging, log_metrics, log_basic_info, initialize, resume_from
1717

1818

1919
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):

templates/single/tests/test_engines.py renamed to templates/single/tests/test_engines.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import ignite.distributed as idist
77
import torch
88
from ignite.engine.engine import Engine
9-
from single_cg.engines import create_engines, evaluate_function, train_function
10-
from single_cg.events import TrainEvents, train_events_to_attr
9+
from {{project_name}}.engines import create_engines, evaluate_function, train_function
10+
from {{project_name}}.events import TrainEvents, train_events_to_attr
1111
from torch import nn, optim
1212

1313

templates/single/tests/test_handlers.py renamed to templates/single/tests/test_handlers.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ignite.handlers.checkpoint import Checkpoint
1717
from ignite.handlers.early_stopping import EarlyStopping
1818
from ignite.handlers.timing import Timer
19-
from single_cg.handlers import get_handlers, get_logger
19+
from {{project_name}}.handlers import get_handlers, get_logger
2020
from torch import nn, optim
2121

2222

templates/single/tests/test_utils.py renamed to templates/single/tests/test_utils.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from ignite.engine import Engine
99
from ignite.utils import setup_logger
10-
from single_cg.utils import (
10+
from {{project_name}}.utils import (
1111
get_default_parser,
1212
hash_checkpoint,
1313
log_metrics,

0 commit comments

Comments
 (0)