Skip to content

Commit 3be57a8

Browse files
justinvyuxinyuangui2
authored andcommitted
[tune][train] Improve ray.train method deprecations called from Tune functions (ray-project#57810)
Ray Tune users need to stop using `ray.train.get_context()`, `ray.train.get_checkpoint()` and `ray.train.report`. This PR improves the error messages raise if they try to call these instead of the `ray.tune` counterparts. Note that we've already soft-deprecated this usage with a warning message for 6+ months. --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: xgui <xgui@anyscale.com>
1 parent 73703b8 commit 3be57a8

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

python/ray/train/v2/api/train_fn_utils.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ def train_func(config):
101101
validate_config: Configuration passed to the validate_fn. Can contain info
102102
like the validation dataset.
103103
"""
104+
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
105+
106+
if _in_tune_session():
107+
raise DeprecationWarning(
108+
"`ray.train.report` is deprecated when running in a function "
109+
"passed to Ray Tune. Please use `ray.tune.report` instead. "
110+
"See this issue for more context: "
111+
"https://github.com/ray-project/ray/issues/49454"
112+
)
113+
104114
if delete_local_checkpoint_after_upload is None:
105115
delete_local_checkpoint_after_upload = (
106116
checkpoint_upload_mode._default_delete_local_checkpoint_after_upload()
@@ -130,8 +140,16 @@ def get_context() -> TrainContext:
130140
131141
See the :class:`~ray.train.TrainContext` API reference to see available methods.
132142
"""
133-
# TODO: Return a dummy train context on the controller and driver process
134-
# instead of raising an exception if the train context does not exist.
143+
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
144+
145+
if _in_tune_session():
146+
raise DeprecationWarning(
147+
"`ray.train.get_context` is deprecated when running in a function "
148+
"passed to Ray Tune. Please use `ray.tune.get_context` instead. "
149+
"See this issue for more context: "
150+
"https://github.com/ray-project/ray/issues/49454"
151+
)
152+
135153
return get_train_fn_utils().get_context()
136154

137155

@@ -179,6 +197,16 @@ def train_func(config):
179197
Checkpoint object if the session is currently being resumed.
180198
Otherwise, return None.
181199
"""
200+
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
201+
202+
if _in_tune_session():
203+
raise DeprecationWarning(
204+
"`ray.train.get_checkpoint` is deprecated when running in a function "
205+
"passed to Ray Tune. Please use `ray.tune.get_checkpoint` instead. "
206+
"See this issue for more context: "
207+
"https://github.com/ray-project/ray/issues/49454"
208+
)
209+
182210
return get_train_fn_utils().get_checkpoint()
183211

184212

python/ray/tune/tests/test_api_migrations.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import functools
2+
import importlib
13
import sys
24
import warnings
35

@@ -9,31 +11,48 @@
911
from ray.util.annotations import RayDeprecationWarning
1012

1113

14+
@pytest.fixture(autouse=True)
15+
def enable_v2(monkeypatch):
16+
monkeypatch.setenv("RAY_TRAIN_V2_ENABLED", "1")
17+
importlib.reload(ray.train)
18+
yield
19+
20+
1221
@pytest.fixture(autouse=True)
1322
def enable_v2_migration_deprecation_messages(monkeypatch):
1423
monkeypatch.setenv(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, "1")
1524
yield
1625
monkeypatch.delenv(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR)
1726

1827

19-
def test_trainable_fn_utils(tmp_path):
28+
@pytest.mark.parametrize("v2_enabled", [False, True])
29+
def test_trainable_fn_utils(tmp_path, monkeypatch, v2_enabled):
30+
monkeypatch.setenv("RAY_TRAIN_V2_ENABLED", str(int(v2_enabled)))
31+
importlib.reload(ray.train)
32+
2033
dummy_checkpoint_dir = tmp_path.joinpath("dummy")
2134
dummy_checkpoint_dir.mkdir()
2235

36+
asserting_context = (
37+
functools.partial(pytest.raises, DeprecationWarning)
38+
if v2_enabled
39+
else functools.partial(pytest.warns, RayDeprecationWarning)
40+
)
41+
2342
def tune_fn(config):
24-
with pytest.warns(RayDeprecationWarning, match="ray.tune.get_checkpoint"):
43+
with asserting_context(match="ray.tune.get_checkpoint"):
2544
ray.train.get_checkpoint()
2645

2746
with warnings.catch_warnings():
2847
ray.tune.get_checkpoint()
2948

30-
with pytest.warns(RayDeprecationWarning, match="ray.tune.get_context"):
49+
with asserting_context(match="ray.tune.get_context"):
3150
ray.train.get_context()
3251

3352
with warnings.catch_warnings():
3453
ray.tune.get_context()
3554

36-
with pytest.warns(RayDeprecationWarning, match="ray.tune.report"):
55+
with asserting_context(match="ray.tune.report"):
3756
ray.train.report({"a": 1})
3857

3958
with warnings.catch_warnings():

0 commit comments

Comments
 (0)