11import tempfile
2+ from abc import abstractmethod
23from contextlib import contextmanager
34from pathlib import Path
45from typing import Callable , Dict , List , Optional , Union
1213from ray .util .annotations import PublicAPI
1314
1415
15- @PublicAPI (stability = "beta" )
16- class RayTrainReportCallback :
17- """Creates a callback that reports metrics and checkpoints model.
18-
19- Args:
20- metrics: Metrics to report. If this is a list,
21- each item should be a metric key reported by LightGBM,
22- and it will be reported to Ray Train/Tune under the same name.
23- This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
24- which can be used to rename LightGBM default metrics.
25- filename: Customize the saved checkpoint file type by passing
26- a filename. Defaults to "model.txt".
27- frequency: How often to save checkpoints, in terms of iterations.
28- Defaults to 0 (no checkpoints are saved during training).
29- checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
30- results_postprocessing_fn: An optional Callable that takes in
31- the metrics dict that will be reported (after it has been flattened)
32- and returns a modified dict.
33-
34- Examples
35- --------
36-
37- Reporting checkpoints and metrics to Ray Tune when running many
38- independent xgboost trials (without data parallelism within a trial).
39-
40- .. testcode::
41- :skipif: True
42-
43- import lightgbm
44-
45- from ray.train.lightgbm import RayTrainReportCallback
46-
47- config = {
48- # ...
49- "metric": ["binary_logloss", "binary_error"],
50- }
51-
52- # Report only log loss to Tune after each validation epoch.
53- bst = lightgbm.train(
54- ...,
55- callbacks=[
56- RayTrainReportCallback(
57- metrics={"loss": "eval-binary_logloss"}, frequency=1
58- )
59- ],
60- )
61-
62- Loading a model from a checkpoint reported by this callback.
63-
64- .. testcode::
65- :skipif: True
66-
67- from ray.train.lightgbm import RayTrainReportCallback
68-
69- # Get a `Checkpoint` object that is saved by the callback during training.
70- result = trainer.fit()
71- booster = RayTrainReportCallback.get_model(result.checkpoint)
72-
73- """
74-
16+ class RayReportCallback :
7517 CHECKPOINT_NAME = "model.txt"
7618
7719 def __init__ (
@@ -103,6 +45,8 @@ def get_model(
10345 The checkpoint should be saved by an instance of this callback.
10446 filename: The filename to load the model from, which should match
10547 the filename used when creating the callback.
48+ Returns:
49+ The model loaded from the checkpoint.
10650 """
10751 with checkpoint .as_directory () as checkpoint_path :
10852 return Booster (model_file = Path (checkpoint_path , filename ).as_posix ())
@@ -140,14 +84,29 @@ def _get_eval_result(self, env: CallbackEnv) -> dict:
14084 eval_result [data_name ][eval_name + "-stdv" ] = stdv
14185 return eval_result
14286
143- @contextmanager
87+ @abstractmethod
14488 def _get_checkpoint (self , model : Booster ) -> Optional [Checkpoint ]:
145- if ray .train .get_context ().get_world_rank () in (0 , None ):
146- with tempfile .TemporaryDirectory () as temp_checkpoint_dir :
147- model .save_model (Path (temp_checkpoint_dir , self ._filename ).as_posix ())
148- yield Checkpoint .from_directory (temp_checkpoint_dir )
149- else :
150- yield None
89+ """Get checkpoint from model.
90+
91+ This method needs to be implemented by subclasses.
92+ """
93+ raise NotImplementedError
94+
95+ @abstractmethod
96+ def _save_and_report_checkpoint (self , report_dict : Dict , model : Booster ):
97+ """Save checkpoint and report metrics corresonding to this checkpoint.
98+
99+ This method needs to be implemented by subclasses.
100+ """
101+ raise NotImplementedError
102+
103+ @abstractmethod
104+ def _report_metrics (self , report_dict : Dict ):
105+ """Report Metrics.
106+
107+ This method needs to be implemented by subclasses.
108+ """
109+ raise NotImplementedError
151110
152111 def __call__ (self , env : CallbackEnv ) -> None :
153112 eval_result = self ._get_eval_result (env )
@@ -164,7 +123,83 @@ def __call__(self, env: CallbackEnv) -> None:
164123 should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency
165124
166125 if should_checkpoint :
167- with self ._get_checkpoint (model = env .model ) as checkpoint :
168- ray .train .report (report_dict , checkpoint = checkpoint )
126+ self ._save_and_report_checkpoint (report_dict , env .model )
169127 else :
170- ray .train .report (report_dict )
128+ self ._report_metrics (report_dict )
129+
130+
@PublicAPI(stability="beta")
class RayTrainReportCallback(RayReportCallback):
    """Callback that reports metrics and checkpoints the model to Ray Train.

    Args:
        metrics: Metrics to report. If this is a list, each item should be a
            metric key reported by LightGBM, and it will be reported to
            Ray Train/Tune under the same name. This can also be a dict of
            {<key-to-report>: <lightgbm-metric-key>}, which can be used to
            rename LightGBM default metrics.
        filename: Customize the saved checkpoint file type by passing
            a filename. Defaults to "model.txt".
        frequency: How often to save checkpoints, in terms of iterations.
            Defaults to 0 (no checkpoints are saved during training).
        checkpoint_at_end: Whether or not to save a checkpoint at the end of
            training.
        results_postprocessing_fn: An optional Callable that takes in
            the metrics dict that will be reported (after it has been
            flattened) and returns a modified dict.

    Examples
    --------

    Reporting checkpoints and metrics to Ray Tune when running many
    independent LightGBM trials (without data parallelism within a trial).

    .. testcode::
        :skipif: True

        import lightgbm

        from ray.train.lightgbm import RayTrainReportCallback

        config = {
            # ...
            "metric": ["binary_logloss", "binary_error"],
        }

        # Report only log loss to Tune after each validation epoch.
        bst = lightgbm.train(
            ...,
            callbacks=[
                RayTrainReportCallback(
                    metrics={"loss": "eval-binary_logloss"}, frequency=1
                )
            ],
        )

    Loading a model from a checkpoint reported by this callback.

    .. testcode::
        :skipif: True

        from ray.train.lightgbm import RayTrainReportCallback

        # Get a `Checkpoint` object that is saved by the callback during training.
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    """

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
        """Yield a checkpoint of ``model`` on rank 0 (or no-rank), else None.

        The checkpoint directory only lives for the duration of the
        ``with`` block, so callers must report it before exiting.
        """
        rank = ray.train.get_context().get_world_rank()
        # Only the rank-0 worker (or a rank-less context, i.e. `None`)
        # materializes a checkpoint; other ranks yield None.
        if rank not in (0, None):
            yield None
            return
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_model(Path(tmp_dir, self._filename).as_posix())
            yield Checkpoint.from_directory(tmp_dir)

    def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
        """Save a model checkpoint and report it together with the metrics."""
        with self._get_checkpoint(model=model) as ckpt:
            ray.train.report(report_dict, checkpoint=ckpt)

    def _report_metrics(self, report_dict: Dict):
        """Report metrics only, with no checkpoint attached."""
        ray.train.report(report_dict)
0 commit comments