Skip to content

Commit bc9723a

Browse files
liulehui and justinvyu authored
[train][tune] Fix LightGBM v2 callbacks for Tune only usage (#57042)
1. in the ray train [revamp REP](https://github.com/ray-project/enhancements/blob/main/reps/2024-10-18-train-tune-api-revamp/2024-10-18-train-tune-api-revamp.md#tune-only-usage), we decouple the ray train/ray tune dependency. 2. Hence, when using RayTrainReportCallback when reporting metrics or checkpoint: the v2 context api will throw RuntimeError that TrainFnUtils is not found. 3. in this PR, refactor the Callback by inheriting the same base class but using `ray.tune.report` for tune only and `ray.train.report` for `RayTrainReportCallback` based on migration example [here](https://github.com/ray-project/enhancements/blob/main/reps/2024-10-18-train-tune-api-revamp/2024-10-18-train-tune-api-revamp.md#tune-only-usage) to further differentiate these callbacks. --------- Signed-off-by: Lehui Liu <lehui@anyscale.com> Signed-off-by: Justin Yu <justinvyu@anyscale.com> Co-authored-by: Justin Yu <justinvyu@anyscale.com>
1 parent 4b35b3d commit bc9723a

File tree

3 files changed

+179
-77
lines changed

3 files changed

+179
-77
lines changed

ci/lint/pydoclint-baseline.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,9 +1881,6 @@ python/ray/train/horovod/horovod_trainer.py
18811881
DOC104: Method `HorovodTrainer.__init__`: Arguments are the same in the docstring and the function signature, but are in a different order.
18821882
DOC105: Method `HorovodTrainer.__init__`: Argument names match, but type hints in these args do not match: train_loop_per_worker, train_loop_config, horovod_config, scaling_config, dataset_config, run_config, datasets, metadata, resume_from_checkpoint
18831883
--------------------
1884-
python/ray/train/lightgbm/_lightgbm_utils.py
1885-
DOC201: Method `RayTrainReportCallback.get_model` does not have a return section in docstring
1886-
--------------------
18871884
python/ray/train/lightgbm/lightgbm_predictor.py
18881885
DOC201: Method `LightGBMPredictor.from_checkpoint` does not have a return section in docstring
18891886
--------------------
Lines changed: 105 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tempfile
2+
from abc import abstractmethod
23
from contextlib import contextmanager
34
from pathlib import Path
45
from typing import Callable, Dict, List, Optional, Union
@@ -12,66 +13,7 @@
1213
from ray.util.annotations import PublicAPI
1314

1415

15-
@PublicAPI(stability="beta")
16-
class RayTrainReportCallback:
17-
"""Creates a callback that reports metrics and checkpoints model.
18-
19-
Args:
20-
metrics: Metrics to report. If this is a list,
21-
each item should be a metric key reported by LightGBM,
22-
and it will be reported to Ray Train/Tune under the same name.
23-
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
24-
which can be used to rename LightGBM default metrics.
25-
filename: Customize the saved checkpoint file type by passing
26-
a filename. Defaults to "model.txt".
27-
frequency: How often to save checkpoints, in terms of iterations.
28-
Defaults to 0 (no checkpoints are saved during training).
29-
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
30-
results_postprocessing_fn: An optional Callable that takes in
31-
the metrics dict that will be reported (after it has been flattened)
32-
and returns a modified dict.
33-
34-
Examples
35-
--------
36-
37-
Reporting checkpoints and metrics to Ray Tune when running many
38-
independent xgboost trials (without data parallelism within a trial).
39-
40-
.. testcode::
41-
:skipif: True
42-
43-
import lightgbm
44-
45-
from ray.train.lightgbm import RayTrainReportCallback
46-
47-
config = {
48-
# ...
49-
"metric": ["binary_logloss", "binary_error"],
50-
}
51-
52-
# Report only log loss to Tune after each validation epoch.
53-
bst = lightgbm.train(
54-
...,
55-
callbacks=[
56-
RayTrainReportCallback(
57-
metrics={"loss": "eval-binary_logloss"}, frequency=1
58-
)
59-
],
60-
)
61-
62-
Loading a model from a checkpoint reported by this callback.
63-
64-
.. testcode::
65-
:skipif: True
66-
67-
from ray.train.lightgbm import RayTrainReportCallback
68-
69-
# Get a `Checkpoint` object that is saved by the callback during training.
70-
result = trainer.fit()
71-
booster = RayTrainReportCallback.get_model(result.checkpoint)
72-
73-
"""
74-
16+
class RayReportCallback:
7517
CHECKPOINT_NAME = "model.txt"
7618

7719
def __init__(
@@ -103,6 +45,8 @@ def get_model(
10345
The checkpoint should be saved by an instance of this callback.
10446
filename: The filename to load the model from, which should match
10547
the filename used when creating the callback.
48+
Returns:
49+
The model loaded from the checkpoint.
10650
"""
10751
with checkpoint.as_directory() as checkpoint_path:
10852
return Booster(model_file=Path(checkpoint_path, filename).as_posix())
@@ -140,14 +84,29 @@ def _get_eval_result(self, env: CallbackEnv) -> dict:
14084
eval_result[data_name][eval_name + "-stdv"] = stdv
14185
return eval_result
14286

143-
@contextmanager
87+
@abstractmethod
14488
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
145-
if ray.train.get_context().get_world_rank() in (0, None):
146-
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
147-
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
148-
yield Checkpoint.from_directory(temp_checkpoint_dir)
149-
else:
150-
yield None
89+
"""Get checkpoint from model.
90+
91+
This method needs to be implemented by subclasses.
92+
"""
93+
raise NotImplementedError
94+
95+
@abstractmethod
96+
def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
97+
"""Save checkpoint and report metrics corresonding to this checkpoint.
98+
99+
This method needs to be implemented by subclasses.
100+
"""
101+
raise NotImplementedError
102+
103+
@abstractmethod
104+
def _report_metrics(self, report_dict: Dict):
105+
"""Report Metrics.
106+
107+
This method needs to be implemented by subclasses.
108+
"""
109+
raise NotImplementedError
151110

152111
def __call__(self, env: CallbackEnv) -> None:
153112
eval_result = self._get_eval_result(env)
@@ -164,7 +123,83 @@ def __call__(self, env: CallbackEnv) -> None:
164123
should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency
165124

166125
if should_checkpoint:
167-
with self._get_checkpoint(model=env.model) as checkpoint:
168-
ray.train.report(report_dict, checkpoint=checkpoint)
126+
self._save_and_report_checkpoint(report_dict, env.model)
169127
else:
170-
ray.train.report(report_dict)
128+
self._report_metrics(report_dict)
129+
130+
131+
@PublicAPI(stability="beta")
132+
class RayTrainReportCallback(RayReportCallback):
133+
"""Creates a callback that reports metrics and checkpoints model.
134+
135+
Args:
136+
metrics: Metrics to report. If this is a list,
137+
each item should be a metric key reported by LightGBM,
138+
and it will be reported to Ray Train/Tune under the same name.
139+
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
140+
which can be used to rename LightGBM default metrics.
141+
filename: Customize the saved checkpoint file type by passing
142+
a filename. Defaults to "model.txt".
143+
frequency: How often to save checkpoints, in terms of iterations.
144+
Defaults to 0 (no checkpoints are saved during training).
145+
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
146+
results_postprocessing_fn: An optional Callable that takes in
147+
the metrics dict that will be reported (after it has been flattened)
148+
and returns a modified dict.
149+
150+
Examples
151+
--------
152+
153+
Reporting checkpoints and metrics to Ray Tune when running many
154+
independent LightGBM trials (without data parallelism within a trial).
155+
156+
.. testcode::
157+
:skipif: True
158+
159+
import lightgbm
160+
161+
from ray.train.lightgbm import RayTrainReportCallback
162+
163+
config = {
164+
# ...
165+
"metric": ["binary_logloss", "binary_error"],
166+
}
167+
168+
# Report only log loss to Tune after each validation epoch.
169+
bst = lightgbm.train(
170+
...,
171+
callbacks=[
172+
RayTrainReportCallback(
173+
metrics={"loss": "eval-binary_logloss"}, frequency=1
174+
)
175+
],
176+
)
177+
178+
Loading a model from a checkpoint reported by this callback.
179+
180+
.. testcode::
181+
:skipif: True
182+
183+
from ray.train.lightgbm import RayTrainReportCallback
184+
185+
# Get a `Checkpoint` object that is saved by the callback during training.
186+
result = trainer.fit()
187+
booster = RayTrainReportCallback.get_model(result.checkpoint)
188+
189+
"""
190+
191+
@contextmanager
192+
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
193+
if ray.train.get_context().get_world_rank() in (0, None):
194+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
195+
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
196+
yield Checkpoint.from_directory(temp_checkpoint_dir)
197+
else:
198+
yield None
199+
200+
def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
201+
with self._get_checkpoint(model=model) as checkpoint:
202+
ray.train.report(report_dict, checkpoint=checkpoint)
203+
204+
def _report_metrics(self, report_dict: Dict):
205+
ray.train.report(report_dict)

python/ray/tune/integration/lightgbm.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,77 @@
1-
from ray.train.lightgbm import ( # noqa: F401
2-
RayTrainReportCallback as TuneReportCheckpointCallback,
3-
)
4-
from ray.util.annotations import Deprecated
1+
import tempfile
2+
from contextlib import contextmanager
3+
from pathlib import Path
4+
from typing import Dict, Optional
5+
6+
from lightgbm import Booster
7+
8+
import ray.tune
9+
from ray.train.lightgbm._lightgbm_utils import RayReportCallback
10+
from ray.tune import Checkpoint
11+
from ray.util.annotations import Deprecated, PublicAPI
12+
13+
14+
@PublicAPI(stability="beta")
15+
class TuneReportCheckpointCallback(RayReportCallback):
16+
"""Creates a callback that reports metrics and checkpoints model.
17+
18+
Args:
19+
metrics: Metrics to report. If this is a list,
20+
each item should be a metric key reported by LightGBM,
21+
and it will be reported to Ray Train/Tune under the same name.
22+
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
23+
which can be used to rename LightGBM default metrics.
24+
filename: Customize the saved checkpoint file type by passing
25+
a filename. Defaults to "model.txt".
26+
frequency: How often to save checkpoints, in terms of iterations.
27+
Defaults to 0 (no checkpoints are saved during training).
28+
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
29+
results_postprocessing_fn: An optional Callable that takes in
30+
the metrics dict that will be reported (after it has been flattened)
31+
and returns a modified dict.
32+
33+
Examples
34+
--------
35+
36+
Reporting checkpoints and metrics to Ray Tune when running many
37+
independent LightGBM trials (without data parallelism within a trial).
38+
39+
.. testcode::
40+
:skipif: True
41+
42+
import lightgbm
43+
44+
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
45+
46+
config = {
47+
# ...
48+
"metric": ["binary_logloss", "binary_error"],
49+
}
50+
51+
# Report only log loss to Tune after each validation epoch.
52+
bst = lightgbm.train(
53+
...,
54+
callbacks=[
55+
TuneReportCheckpointCallback(
56+
metrics={"loss": "eval-binary_logloss"}, frequency=1
57+
)
58+
],
59+
)
60+
61+
"""
62+
63+
@contextmanager
64+
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
65+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
66+
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
67+
yield Checkpoint.from_directory(temp_checkpoint_dir)
68+
69+
def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
70+
with self._get_checkpoint(model=model) as checkpoint:
71+
ray.tune.report(report_dict, checkpoint=checkpoint)
72+
73+
def _report_metrics(self, report_dict: Dict):
74+
ray.tune.report(report_dict)
575

676

777
@Deprecated

0 commit comments

Comments
 (0)