Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 105 additions & 70 deletions python/ray/train/lightgbm/_lightgbm_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
Expand All @@ -12,66 +13,7 @@
from ray.util.annotations import PublicAPI


@PublicAPI(stability="beta")
class RayTrainReportCallback:
"""Creates a callback that reports metrics and checkpoints model.

Args:
metrics: Metrics to report. If this is a list,
each item should be a metric key reported by LightGBM,
and it will be reported to Ray Train/Tune under the same name.
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
which can be used to rename LightGBM default metrics.
filename: Customize the saved checkpoint file type by passing
a filename. Defaults to "model.txt".
frequency: How often to save checkpoints, in terms of iterations.
Defaults to 0 (no checkpoints are saved during training).
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
results_postprocessing_fn: An optional Callable that takes in
the metrics dict that will be reported (after it has been flattened)
and returns a modified dict.

Examples
--------

Reporting checkpoints and metrics to Ray Tune when running many
independent xgboost trials (without data parallelism within a trial).

.. testcode::
:skipif: True

import lightgbm

from ray.train.lightgbm import RayTrainReportCallback

config = {
# ...
"metric": ["binary_logloss", "binary_error"],
}

# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
...,
callbacks=[
RayTrainReportCallback(
metrics={"loss": "eval-binary_logloss"}, frequency=1
)
],
)

Loading a model from a checkpoint reported by this callback.

.. testcode::
:skipif: True

from ray.train.lightgbm import RayTrainReportCallback

# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)

"""

class RayReportCallback:
CHECKPOINT_NAME = "model.txt"

def __init__(
Expand Down Expand Up @@ -103,6 +45,8 @@ def get_model(
The checkpoint should be saved by an instance of this callback.
filename: The filename to load the model from, which should match
the filename used when creating the callback.
Returns:
The model loaded from the checkpoint.
"""
with checkpoint.as_directory() as checkpoint_path:
return Booster(model_file=Path(checkpoint_path, filename).as_posix())
Expand Down Expand Up @@ -140,14 +84,29 @@ def _get_eval_result(self, env: CallbackEnv) -> dict:
eval_result[data_name][eval_name + "-stdv"] = stdv
return eval_result

@contextmanager
@abstractmethod
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
if ray.train.get_context().get_world_rank() in (0, None):
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
yield Checkpoint.from_directory(temp_checkpoint_dir)
else:
yield None
"""Get checkpoint from model.

This method needs to be implemented by subclasses.
"""
raise NotImplementedError
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Missing Context Manager Decorator

The abstract method _get_checkpoint lacks the @contextmanager decorator. This creates an interface mismatch, as its concrete implementations use this decorator and the method is consistently called as a context manager in _save_and_report_checkpoint.

Fix in Cursor Fix in Web


@abstractmethod
def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
"""Save checkpoint and report metrics corresonding to this checkpoint.

This method needs to be implemented by subclasses.
"""
raise NotImplementedError

@abstractmethod
def _report_metrics(self, report_dict: Dict):
"""Report Metrics.

This method needs to be implemented by subclasses.
"""
raise NotImplementedError

def __call__(self, env: CallbackEnv) -> None:
eval_result = self._get_eval_result(env)
Expand All @@ -164,7 +123,83 @@ def __call__(self, env: CallbackEnv) -> None:
should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency

if should_checkpoint:
with self._get_checkpoint(model=env.model) as checkpoint:
ray.train.report(report_dict, checkpoint=checkpoint)
self._save_and_report_checkpoint(report_dict, env.model)
else:
ray.train.report(report_dict)
self._report_metrics(report_dict)


@PublicAPI(stability="beta")
class RayTrainReportCallback(RayReportCallback):
"""Creates a callback that reports metrics and checkpoints model.

Args:
metrics: Metrics to report. If this is a list,
each item should be a metric key reported by LightGBM,
and it will be reported to Ray Train/Tune under the same name.
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
which can be used to rename LightGBM default metrics.
filename: Customize the saved checkpoint file type by passing
a filename. Defaults to "model.txt".
frequency: How often to save checkpoints, in terms of iterations.
Defaults to 0 (no checkpoints are saved during training).
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
results_postprocessing_fn: An optional Callable that takes in
the metrics dict that will be reported (after it has been flattened)
and returns a modified dict.

Examples
--------

Reporting checkpoints and metrics to Ray Tune when running many
independent LightGBM trials (without data parallelism within a trial).

.. testcode::
:skipif: True

import lightgbm

from ray.train.lightgbm import RayTrainReportCallback

config = {
# ...
"metric": ["binary_logloss", "binary_error"],
}

# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
...,
callbacks=[
RayTrainReportCallback(
metrics={"loss": "eval-binary_logloss"}, frequency=1
)
],
)

Loading a model from a checkpoint reported by this callback.

.. testcode::
:skipif: True

from ray.train.lightgbm import RayTrainReportCallback

# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)

"""

@contextmanager
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
if ray.train.get_context().get_world_rank() in (0, None):
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
yield Checkpoint.from_directory(temp_checkpoint_dir)
else:
yield None

def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
with self._get_checkpoint(model=model) as checkpoint:
ray.train.report(report_dict, checkpoint=checkpoint)

def _report_metrics(self, report_dict: Dict):
ray.train.report(report_dict)
89 changes: 85 additions & 4 deletions python/ray/tune/integration/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,88 @@
from ray.train.lightgbm import ( # noqa: F401
RayTrainReportCallback as TuneReportCheckpointCallback,
)
from ray.util.annotations import Deprecated
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Optional

from lightgbm.basic import Booster

import ray.tune
from ray.train.lightgbm._lightgbm_utils import RayReportCallback
from ray.tune import Checkpoint
from ray.util.annotations import Deprecated, PublicAPI


@PublicAPI(stability="beta")
class TuneReportCheckpointCallback(RayReportCallback):
"""Creates a callback that reports metrics and checkpoints model.

Args:
metrics: Metrics to report. If this is a list,
each item should be a metric key reported by LightGBM,
and it will be reported to Ray Train/Tune under the same name.
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
which can be used to rename LightGBM default metrics.
filename: Customize the saved checkpoint file type by passing
a filename. Defaults to "model.txt".
frequency: How often to save checkpoints, in terms of iterations.
Defaults to 0 (no checkpoints are saved during training).
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
results_postprocessing_fn: An optional Callable that takes in
the metrics dict that will be reported (after it has been flattened)
and returns a modified dict.

Examples
--------

Reporting checkpoints and metrics to Ray Tune when running many
independent LightGBM trials (without data parallelism within a trial).

.. testcode::
:skipif: True

import lightgbm

from ray.tune.integration.lightgbm import TuneReportCheckpointCallback

config = {
# ...
"metric": ["binary_logloss", "binary_error"],
}

# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
...,
callbacks=[
TuneReportCheckpointCallback(
metrics={"loss": "eval-binary_logloss"}, frequency=1
)
],
)

Loading a model from a checkpoint reported by this callback.

.. testcode::
:skipif: True

from ray.tune.integration.lightgbm import TuneReportCheckpointCallback

# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
booster = TuneReportCheckpointCallback.get_model(result.checkpoint)

"""

@contextmanager
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
yield Checkpoint(temp_checkpoint_dir)

def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
with self._get_checkpoint(model=model) as checkpoint:
ray.tune.report(report_dict, checkpoint=checkpoint)

def _report_metrics(self, report_dict: Dict):
ray.tune.report(report_dict)


@Deprecated
Expand Down