7
7
8
8
import ignite .distributed as idist
9
9
import torch
10
- import yaml
11
10
from ignite .contrib .engines import common
12
11
from ignite .engine import Engine
13
12
35
34
36
35
#::: } :::#
37
36
from ignite .utils import setup_logger
37
+ from omegaconf import DictConfig , OmegaConf
38
38
39
39
40
40
def get_default_parser ():
@@ -57,17 +57,11 @@ def setup_config(parser=None):
57
57
args = parser .parse_args ()
58
58
config_path = args .config
59
59
60
- with open (config_path , "r" ) as f :
61
- config = yaml .safe_load (f .read ())
60
+ config = OmegaConf .load (config_path )
62
61
63
- optional_attributes = ["train_epoch_length" , "eval_epoch_length" ]
64
- for attr in optional_attributes :
65
- config [attr ] = config .get (attr , None )
62
+ config .backend = args .backend
66
63
67
- for k , v in config .items ():
68
- setattr (args , k , v )
69
-
70
- return args
64
+ return DictConfig (config )
71
65
72
66
73
67
def log_metrics (engine : Engine , tag : str ) -> None :
@@ -138,6 +132,12 @@ def setup_output_dir(config: Any, rank: int) -> Path:
138
132
return Path (idist .broadcast (config .output_dir , src = 0 ))
139
133
140
134
135
+ def save_config (config , output_dir ):
136
+ """Save configuration to config-lock.yaml for result reproducibility."""
137
+ with open (f"{ output_dir } /config-lock.yaml" , "w" ) as f :
138
+ OmegaConf .save (config , f )
139
+
140
+
141
141
def setup_logging (config : Any ) -> Logger :
142
142
"""Setup logger with `ignite.utils.setup_logger()`.
143
143
0 commit comments