Skip to content

Commit e03208c

Browse files
justinvyuelliot-barn
authored andcommitted
[train] Enable deprecation warning for legacy xgboost/lightgbm trainer APIs (#57280)
Emit the deprecation warning for the legacy XGBoostTrainer API, in preparation for Train V2 migration (which only supports the new custom training function API). --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
1 parent c39ec3f commit e03208c

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

python/ray/train/lightgbm/lightgbm_trainer.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from ray.train.lightgbm.config import LightGBMConfig
1212
from ray.train.lightgbm.v2 import LightGBMTrainer as SimpleLightGBMTrainer
1313
from ray.train.trainer import GenDataset
14+
from ray.train.utils import _log_deprecation_warning
1415
from ray.util.annotations import PublicAPI
1516

1617
logger = logging.getLogger(__name__)
1718

1819

19-
LEGACY_LIGHTGBMGBM_TRAINER_DEPRECATION_MESSAGE = (
20+
LEGACY_LIGHTGBM_TRAINER_DEPRECATION_MESSAGE = (
2021
"Passing in `lightgbm.train` kwargs such as `params`, `num_boost_round`, "
2122
"`label_column`, etc. to `LightGBMTrainer` is deprecated "
2223
"in favor of the new API which accepts a `train_loop_per_worker` argument, "
@@ -228,15 +229,14 @@ def __init__(
228229
datasets=datasets,
229230
)
230231
train_loop_config = params or {}
231-
# TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
232-
# elif train_kwargs:
233-
# _log_deprecation_warning(
234-
# "Passing `lightgbm.train` kwargs to `LightGBMTrainer` is deprecated. "
235-
# f"Got kwargs: {train_kwargs.keys()}\n"
236-
# "Please pass in a `train_loop_per_worker` function instead, "
237-
# "which has full flexibility on the call to `lightgbm.train(**kwargs)`. "
238-
# f"{LEGACY_LIGHTGBMGBM_TRAINER_DEPRECATION_MESSAGE}"
239-
# )
232+
elif train_kwargs:
233+
_log_deprecation_warning(
234+
"Passing `lightgbm.train` kwargs to `LightGBMTrainer` is deprecated. "
235+
f"Got kwargs: {train_kwargs.keys()}\n"
236+
"In your training function, you can call `lightgbm.train(**kwargs)` "
237+
"with arbitrary arguments. "
238+
f"{LEGACY_LIGHTGBM_TRAINER_DEPRECATION_MESSAGE}"
239+
)
240240

241241
super(LightGBMTrainer, self).__init__(
242242
train_loop_per_worker=train_loop_per_worker,
@@ -278,8 +278,7 @@ def _get_legacy_train_fn_per_worker(
278278

279279
num_boost_round = num_boost_round or 10
280280

281-
# TODO: [Deprecated] Legacy LightGBMTrainer API
282-
# _log_deprecation_warning(LEGACY_LIGHTGBMGBM_TRAINER_DEPRECATION_MESSAGE)
281+
_log_deprecation_warning(LEGACY_LIGHTGBM_TRAINER_DEPRECATION_MESSAGE)
283282

284283
# Initialize a default Ray Train metrics/checkpoint reporting callback if needed
285284
callbacks = lightgbm_train_kwargs.get("callbacks", [])

python/ray/train/xgboost/xgboost_trainer.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ray.train import Checkpoint
1010
from ray.train.constants import TRAIN_DATASET_KEY
1111
from ray.train.trainer import GenDataset
12+
from ray.train.utils import _log_deprecation_warning
1213
from ray.train.xgboost import RayTrainReportCallback, XGBoostConfig
1314
from ray.train.xgboost.v2 import XGBoostTrainer as SimpleXGBoostTrainer
1415
from ray.util.annotations import PublicAPI
@@ -19,7 +20,7 @@
1920
LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE = (
2021
"Passing in `xgboost.train` kwargs such as `params`, `num_boost_round`, "
2122
"`label_column`, etc. to `XGBoostTrainer` is deprecated "
22-
"in favor of the new API which accepts a ``train_loop_per_worker`` argument, "
23+
"in favor of the new API which accepts a training function, "
2324
"similar to the other DataParallelTrainer APIs (ex: TorchTrainer). "
2425
"See this issue for more context: "
2526
"https://github.com/ray-project/ray/issues/50042"
@@ -228,14 +229,13 @@ def __init__(
228229
datasets=datasets,
229230
)
230231
train_loop_config = params or {}
231-
# TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
232-
# elif train_kwargs:
233-
# _log_deprecation_warning(
234-
# "Passing `xgboost.train` kwargs to `XGBoostTrainer` is deprecated. "
235-
# "Please pass in a `train_loop_per_worker` function instead, "
236-
# "which has full flexibility on the call to `xgboost.train(**kwargs)`. "
237-
# f"{LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE}"
238-
# )
232+
elif train_kwargs:
233+
_log_deprecation_warning(
234+
"Passing `xgboost.train` kwargs to `XGBoostTrainer` is deprecated. "
235+
"In your training function, you can call `xgboost.train(**kwargs)` "
236+
"with arbitrary arguments. "
237+
f"{LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE}"
238+
)
239239

240240
super(XGBoostTrainer, self).__init__(
241241
train_loop_per_worker=train_loop_per_worker,
@@ -277,8 +277,7 @@ def _get_legacy_train_fn_per_worker(
277277

278278
num_boost_round = num_boost_round or 10
279279

280-
# TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
281-
# _log_deprecation_warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)
280+
_log_deprecation_warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)
282281

283282
# Initialize a default Ray Train metrics/checkpoint reporting callback if needed
284283
callbacks = xgboost_train_kwargs.get("callbacks", [])

0 commit comments

Comments
 (0)