 def get_configs() -> dict:
     config = {}
     with st.beta_expander("Training Configurations"):
-        st.info(
-            "Common base training configurations. Those in the parenthesis are used in the code."
-        )
+        st.info("Common base training configurations. Those in the parenthesis are used in the code.")

         # group by streamlit function type
-        config["amp_mode"] = st.selectbox(
-            "AMP mode (amp_mode)", ("None", "amp", "apex")
-        )
-        config["device"] = st.selectbox(
-            "Device to use (device)", ("cpu", "cuda", "xla")
-        )
+        config["amp_mode"] = st.selectbox("AMP mode (amp_mode)", ("None", "amp", "apex"))
+        config["device"] = st.selectbox("Device to use (device)", ("cpu", "cuda", "xla"))

         config["data_path"] = st.text_input("Dataset path (data_path)", value="./")
-        config["filepath"] = st.text_input(
-            "Logging file path (filepath)", value="./logs"
-        )
+        config["filepath"] = st.text_input("Logging file path (filepath)", value="./logs")

-        config["train_batch_size"] = st.number_input(
-            "Train batch size (train_batch_size)", min_value=1, value=1
-        )
-        config["eval_batch_size"] = st.number_input(
-            "Eval batch size (eval_batch_size)", min_value=1, value=1
-        )
-        config["num_workers"] = st.number_input(
-            "Number of workers (num_workers)", min_value=0, value=2
-        )
-        config["max_epochs"] = st.number_input(
-            "Maximum epochs to train (max_epochs)", min_value=1, value=2
-        )
+        config["train_batch_size"] = st.number_input("Train batch size (train_batch_size)", min_value=1, value=1)
+        config["eval_batch_size"] = st.number_input("Eval batch size (eval_batch_size)", min_value=1, value=1)
+        config["num_workers"] = st.number_input("Number of workers (num_workers)", min_value=0, value=2)
+        config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2)
         config["lr"] = st.number_input(
-            "Learning rate used by torch.optim.* (lr)",
-            min_value=0.0,
-            value=1e-3,
-            format="%e",
+            "Learning rate used by torch.optim.* (lr)", min_value=0.0, value=1e-3, format="%e",
         )
         config["log_train"] = st.number_input(
             "Logging interval of training iterations (log_train)", min_value=0, value=50
         )
-        config["log_eval"] = st.number_input(
-            "Logging interval of evaluation epoch (log_eval)", min_value=0, value=1
-        )
-        config["seed"] = st.number_input(
-            "Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666
-        )
+        config["log_eval"] = st.number_input("Logging interval of evaluation epoch (log_eval)", min_value=0, value=1)
+        config["seed"] = st.number_input("Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666)
         if st.checkbox("Use distributed training"):
             config["nproc_per_node"] = st.number_input(
-                "Number of processes to launch on each node (nproc_per_node)",
-                min_value=1,
-            )
-            config["nnodes"] = st.number_input(
-                "Number of nodes to use for distributed training (nnodes)", min_value=1,
+                "Number of processes to launch on each node (nproc_per_node)", min_value=1,
             )
+            config["nnodes"] = st.number_input("Number of nodes to use for distributed training (nnodes)", min_value=1,)
             if config["nnodes"] > 1:
                 st.info(
                     "The following options are only supported by torch.distributed, namely 'gloo' and 'nccl' backends."
                     " For other backends, please specify spawn_kwargs in main.py"
                 )
                 config["node_rank"] = st.number_input(
-                    "Rank of the node for multi-node distributed training (node_rank)",
-                    min_value=0,
+                    "Rank of the node for multi-node distributed training (node_rank)", min_value=0,
                 )
                 if config["node_rank"] > (config["nnodes"] - 1):
-                    st.error(
-                        f"node_rank should be between 0 and {config['nnodes'] - 1}"
-                    )
+                    st.error(f"node_rank should be between 0 and {config['nnodes'] - 1}")
             config["master_addr"] = st.text_input(
-                "Master node TCP/IP address for torch native backends (master_addr)",
-                "'127.0.0.1'",
+                "Master node TCP/IP address for torch native backends (master_addr)", "'127.0.0.1'",
             )
             st.warning("Please include single quote in master_addr.")
             config["master_port"] = st.text_input(
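For reference, below is a minimal, self-contained sketch of how the reformatted helper could be exercised end to end. It reproduces only two of the widgets above; the app filename, the final st.json() call, and the explicit return are illustrative assumptions (the return is implied by the -> dict annotation but falls outside the excerpt), and st.beta_expander is the pre-1.0 Streamlit name for what later became st.expander.

# minimal_configs_app.py -- hypothetical demo, run with: streamlit run minimal_configs_app.py
# Assumes a Streamlit release that still ships st.beta_expander (pre-1.0);
# on newer releases substitute st.expander.
import streamlit as st

def get_configs() -> dict:
    config = {}
    with st.beta_expander("Training Configurations"):
        # each widget writes straight into the config dict, keyed by the
        # code-level name shown in parentheses in the widget label
        config["amp_mode"] = st.selectbox("AMP mode (amp_mode)", ("None", "amp", "apex"))
        config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2)
    return config

config = get_configs()
st.json(config)  # render the collected configuration so the label-to-key mapping is easy to inspect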