Skip to content

Commit 58edaa4

Browse files
authored
[train+tune] Local directory refactor (1/n): Write launcher state files (tuner.pkl, trainer.pkl) directly to storage (#43369)
This PR updates `Trainer`s and the `Tuner` to upload their state directly to `storage_path`, rather than dumping it in a local directory and relying on driver syncing to upload. --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com>
1 parent 53fe3fc commit 58edaa4

File tree

10 files changed

+84
-125
lines changed

10 files changed

+84
-125
lines changed

doc/source/train/doc_code/key_concepts.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -129,7 +129,7 @@ def train_fn(config):
129129
result_path: str = result.path
130130
result_filesystem: pyarrow.fs.FileSystem = result.filesystem
131131

132-
print("Results location (fs, path) = ({result_filesystem}, {result_path})")
132+
print(f"Results location (fs, path) = ({result_filesystem}, {result_path})")
133133
# __result_path_end__
134134

135135

doc/source/train/doc_code/tuner.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -68,7 +68,7 @@
6868

6969
tuner = Tuner(
7070
trainable=trainer,
71-
run_config=RunConfig(name="test_tuner"),
71+
run_config=RunConfig(name="test_tuner_xgboost"),
7272
param_space=param_space,
7373
tune_config=tune.TuneConfig(
7474
mode="min", metric="train-logloss", num_samples=2, max_concurrent_trials=2

python/ray/air/tests/test_errors.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -15,6 +15,7 @@
1515
- Assert how errors from the trainable/Trainer get propagated to the user.
1616
- Assert how errors from the Tune driver get propagated to the user.
1717
"""
18+
1819
import gc
1920
import threading
2021
import time
@@ -198,7 +199,7 @@ def test_driver_error_with_tuner(ray_start_4_cpus, error_on):
198199
tuner.fit()
199200

200201
# TODO(ml-team): Assert the cause error type once driver error propagation is fixed
201-
assert "_TestSpecificError" in str(exc_info.value.__cause__)
202+
assert "_TestSpecificError" in str(exc_info.value)
202203

203204

204205
@pytest.mark.parametrize("error_on", ["on_trial_result"])

python/ray/train/base_trainer.py

Lines changed: 27 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -21,7 +21,11 @@
2121
from ray.air.result import Result
2222
from ray.train import Checkpoint
2323
from ray.train._internal.session import _get_session
24-
from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path
24+
from ray.train._internal.storage import (
25+
StorageContext,
26+
_exists_at_fs_path,
27+
get_fs_and_path,
28+
)
2529
from ray.util import PublicAPI
2630
from ray.util.annotations import DeveloperAPI
2731

@@ -226,7 +230,9 @@ def __init__(
226230
self.scaling_config = (
227231
scaling_config if scaling_config is not None else ScalingConfig()
228232
)
229-
self.run_config = run_config if run_config is not None else RunConfig()
233+
self.run_config = (
234+
copy.copy(run_config) if run_config is not None else RunConfig()
235+
)
230236
self.metadata = metadata
231237
self.datasets = datasets if datasets is not None else {}
232238
self.starting_checkpoint = resume_from_checkpoint
@@ -569,11 +575,23 @@ def fit(self) -> Result:
569575
``self.as_trainable()``, or during the Tune execution loop.
570576
"""
571577
from ray.tune import ResumeConfig, TuneError
572-
from ray.tune.tuner import Tuner, TunerInternal
578+
from ray.tune.tuner import Tuner
573579

574580
trainable = self.as_trainable()
575581
param_space = self._extract_fields_for_tuner_param_space()
576582

583+
self.run_config.name = (
584+
self.run_config.name or StorageContext.get_experiment_dir_name(trainable)
585+
)
586+
# The storage context here is only used to access the resolved
587+
# storage fs and experiment path, in order to avoid duplicating that logic.
588+
# This is NOT the storage context object that gets passed to remote workers.
589+
storage = StorageContext(
590+
storage_path=self.run_config.storage_path,
591+
experiment_dir_name=self.run_config.name,
592+
storage_filesystem=self.run_config.storage_filesystem,
593+
)
594+
577595
if self._restore_path:
578596
tuner = Tuner.restore(
579597
path=self._restore_path,
@@ -594,16 +612,11 @@ def fit(self) -> Result:
594612
_entrypoint=AirEntrypoint.TRAINER,
595613
)
596614

597-
experiment_local_path, _ = TunerInternal.setup_create_experiment_checkpoint_dir(
598-
trainable, self.run_config
599-
)
600-
601-
experiment_local_path = Path(experiment_local_path)
602-
self._save(experiment_local_path)
615+
self._save(storage.storage_filesystem, storage.experiment_fs_path)
603616

604617
restore_msg = TrainingFailedError._RESTORE_MSG.format(
605618
trainer_cls_name=self.__class__.__name__,
606-
path=str(experiment_local_path),
619+
path=str(storage.experiment_fs_path),
607620
)
608621

609622
try:
@@ -627,7 +640,7 @@ def fit(self) -> Result:
627640
) from result.error
628641
return result
629642

630-
def _save(self, experiment_path: Union[str, Path]):
643+
def _save(self, fs: pyarrow.fs.FileSystem, experiment_path: str):
631644
"""Saves the current trainer's class along with the `param_dict` of
632645
parameters passed to this trainer's constructor.
633646
@@ -656,9 +669,9 @@ def raise_fn():
656669

657670
cls_and_param_dict = (self.__class__, param_dict)
658671

659-
experiment_path = Path(experiment_path)
660-
with open(experiment_path / _TRAINER_PKL, "wb") as fp:
661-
pickle.dump(cls_and_param_dict, fp)
672+
fs.create_dir(experiment_path)
673+
with fs.open_output_stream(Path(experiment_path, _TRAINER_PKL).as_posix()) as f:
674+
f.write(pickle.dumps(cls_and_param_dict))
662675

663676
def _extract_fields_for_tuner_param_space(self) -> Dict:
664677
"""Extracts fields to be included in `Tuner.param_space`.

python/ray/train/tests/test_trainer_restore.py

Lines changed: 12 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Dict, List
55

6+
import pyarrow.fs
67
import pytest
78

89
import ray
@@ -185,26 +186,24 @@ def test_gbdt_trainer_restore(ray_start_6_cpus, tmp_path, trainer_cls, monkeypat
185186
assert tmp_path / exp_name in Path(result.path).parents
186187

187188

189+
@pytest.mark.parametrize("name", [None, "restore_from_uri"])
188190
def test_restore_from_uri_s3(
189-
ray_start_4_cpus, tmp_path, monkeypatch, mock_s3_bucket_uri
191+
ray_start_4_cpus, tmp_path, monkeypatch, mock_s3_bucket_uri, name
190192
):
191193
"""Restoration from S3 should work."""
192-
monkeypatch.setenv("RAY_AIR_LOCAL_CACHE_DIR", str(tmp_path))
193194
trainer = DataParallelTrainer(
194195
train_loop_per_worker=lambda config: train.report({"score": 1}),
195196
scaling_config=ScalingConfig(num_workers=2),
196-
run_config=RunConfig(name="restore_from_uri", storage_path=mock_s3_bucket_uri),
197+
run_config=RunConfig(name=name, storage_path=mock_s3_bucket_uri),
197198
)
198-
trainer.fit()
199+
result = trainer.fit()
199200

200-
# Restore from local dir
201-
DataParallelTrainer.restore(str(tmp_path / "restore_from_uri"))
201+
if name is None:
202+
name = Path(result.path).parent.name
202203

203204
# Restore from S3
204-
assert DataParallelTrainer.can_restore(
205-
str(URI(mock_s3_bucket_uri) / "restore_from_uri")
206-
)
207-
DataParallelTrainer.restore(str(URI(mock_s3_bucket_uri) / "restore_from_uri"))
205+
assert DataParallelTrainer.can_restore(str(URI(mock_s3_bucket_uri) / name))
206+
DataParallelTrainer.restore(str(URI(mock_s3_bucket_uri) / name))
208207

209208

210209
def test_restore_with_datasets(ray_start_4_cpus, tmpdir):
@@ -220,7 +219,7 @@ def test_restore_with_datasets(ray_start_4_cpus, tmpdir):
220219
scaling_config=ScalingConfig(num_workers=2),
221220
run_config=RunConfig(name="datasets_respecify_test", local_dir=tmpdir),
222221
)
223-
trainer._save(tmpdir)
222+
trainer._save(pyarrow.fs.LocalFileSystem(), str(tmpdir))
224223

225224
# Restore should complain, if all the datasets don't get passed in again
226225
with pytest.raises(ValueError):
@@ -246,7 +245,7 @@ def test_restore_with_different_trainer(tmpdir):
246245
scaling_config=ScalingConfig(num_workers=1),
247246
run_config=RunConfig(name="restore_with_diff_trainer"),
248247
)
249-
trainer._save(tmpdir)
248+
trainer._save(pyarrow.fs.LocalFileSystem(), str(tmpdir))
250249

251250
def attempt_restore(trainer_cls, should_warn: bool, should_raise: bool):
252251
def check_for_raise():
@@ -299,7 +298,7 @@ def test_trainer_can_restore_utility(tmp_path):
299298
scaling_config=ScalingConfig(num_workers=1),
300299
)
301300
(tmp_path / name).mkdir(exist_ok=True)
302-
trainer._save(tmp_path / name)
301+
trainer._save(pyarrow.fs.LocalFileSystem(), str(tmp_path / name))
303302

304303
assert DataParallelTrainer.can_restore(path)
305304

python/ray/train/tests/test_tune.py

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -284,7 +284,7 @@ def test_run_config_in_trainer_and_tuner(
284284
run_config=trainer_run_config,
285285
)
286286
with caplog.at_level(logging.INFO, logger="ray.tune.impl.tuner_internal"):
287-
tuner = Tuner(trainer, run_config=tuner_run_config)
287+
Tuner(trainer, run_config=tuner_run_config)
288288

289289
both_msg = (
290290
"`RunConfig` was passed to both the `Tuner` and the `DataParallelTrainer`"
@@ -302,7 +302,6 @@ def test_run_config_in_trainer_and_tuner(
302302
assert not (tmp_path / "trainer").exists()
303303
assert both_msg not in caplog.text
304304
else:
305-
assert tuner._local_tuner.get_run_config() == RunConfig()
306305
assert both_msg not in caplog.text
307306

308307

python/ray/tune/impl/tuner_internal.py

Lines changed: 24 additions & 55 deletions
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,5 @@
11
import copy
22
import io
3-
import os
43
import math
54
import logging
65
from pathlib import Path
@@ -27,7 +26,6 @@
2726
from ray.tune import Experiment, ExperimentAnalysis, ResumeConfig, TuneError
2827
from ray.tune.tune import _Config
2928
from ray.tune.registry import is_function_trainable
30-
from ray.tune.result import _get_defaults_results_dir
3129
from ray.tune.result_grid import ResultGrid
3230
from ray.tune.trainable import Trainable
3331
from ray.tune.tune import run
@@ -102,7 +100,7 @@ def __init__(
102100
)
103101

104102
self._tune_config = tune_config or TuneConfig()
105-
self._run_config = run_config or RunConfig()
103+
self._run_config = copy.copy(run_config) or RunConfig()
106104
self._entrypoint = _entrypoint
107105

108106
# Restore from Tuner checkpoint.
@@ -129,23 +127,27 @@ def __init__(
129127
self._resume_config = None
130128
self._is_restored = False
131129
self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
132-
(
133-
self._local_experiment_dir,
134-
self._experiment_dir_name,
135-
) = self.setup_create_experiment_checkpoint_dir(
136-
self.converted_trainable, self._run_config
137-
)
138130
self._experiment_analysis = None
139131

140-
# This needs to happen before `tune.run()` is kicked in.
141-
# This is because currently tune does not exit gracefully if
142-
# run in ray client mode - if crash happens, it just exits immediately
143-
# without allowing for checkpointing tuner and trainable.
144-
# Thus this has to happen before tune.run() so that we can have something
145-
# to restore from.
146-
experiment_checkpoint_path = Path(self._local_experiment_dir, _TUNER_PKL)
147-
with open(experiment_checkpoint_path, "wb") as fp:
148-
pickle.dump(self.__getstate__(), fp)
132+
self._run_config.name = (
133+
self._run_config.name
134+
or StorageContext.get_experiment_dir_name(self.converted_trainable)
135+
)
136+
# The storage context here is only used to access the resolved
137+
# storage fs and experiment path, in order to avoid duplicating that logic.
138+
# This is NOT the storage context object that gets passed to remote workers.
139+
storage = StorageContext(
140+
storage_path=self._run_config.storage_path,
141+
experiment_dir_name=self._run_config.name,
142+
storage_filesystem=self._run_config.storage_filesystem,
143+
)
144+
145+
fs = storage.storage_filesystem
146+
fs.create_dir(storage.experiment_fs_path)
147+
with fs.open_output_stream(
148+
Path(storage.experiment_fs_path, _TUNER_PKL).as_posix()
149+
) as f:
150+
f.write(pickle.dumps(self.__getstate__()))
149151

150152
def get_run_config(self) -> RunConfig:
151153
return self._run_config
@@ -349,20 +351,16 @@ def _restore_from_path_or_uri(
349351
# Ex: s3://bucket/exp_name -> s3://bucket, exp_name
350352
self._run_config.name = path_or_uri_obj.name
351353
self._run_config.storage_path = str(path_or_uri_obj.parent)
352-
353-
(
354-
self._local_experiment_dir,
355-
self._experiment_dir_name,
356-
) = self.setup_create_experiment_checkpoint_dir(
357-
self.converted_trainable, self._run_config
358-
)
354+
# Update the storage_filesystem with the one passed in on restoration, if any.
355+
self._run_config.storage_filesystem = storage_filesystem
359356

360357
# Load the experiment results at the point where it left off.
361358
try:
362359
self._experiment_analysis = ExperimentAnalysis(
363360
experiment_checkpoint_path=path_or_uri,
364361
default_metric=self._tune_config.metric,
365362
default_mode=self._tune_config.mode,
363+
storage_filesystem=storage_filesystem,
366364
)
367365
except Exception:
368366
self._experiment_analysis = None
@@ -426,35 +424,6 @@ def _process_scaling_config(self) -> None:
426424
return
427425
self._param_space["scaling_config"] = scaling_config.__dict__.copy()
428426

429-
@classmethod
430-
def setup_create_experiment_checkpoint_dir(
431-
cls, trainable: TrainableType, run_config: Optional[RunConfig]
432-
) -> Tuple[str, str]:
433-
"""Sets up and creates the local experiment checkpoint dir.
434-
This is so that the `tuner.pkl` file gets stored in the same directory
435-
and gets synced with other experiment results.
436-
437-
Returns:
438-
Tuple: (experiment_path, experiment_dir_name)
439-
"""
440-
# TODO(justinvyu): Move this logic into StorageContext somehow
441-
experiment_dir_name = run_config.name or StorageContext.get_experiment_dir_name(
442-
trainable
443-
)
444-
storage_local_path = _get_defaults_results_dir()
445-
experiment_path = (
446-
Path(storage_local_path).joinpath(experiment_dir_name).as_posix()
447-
)
448-
449-
os.makedirs(experiment_path, exist_ok=True)
450-
return experiment_path, experiment_dir_name
451-
452-
# This has to be done through a function signature (@property won't do).
453-
def get_experiment_checkpoint_dir(self) -> str:
454-
# TODO(justinvyu): This is used to populate an error message.
455-
# This should point to the storage path experiment dir instead.
456-
return self._local_experiment_dir
457-
458427
@property
459428
def trainable(self) -> TrainableTypeOrTrainer:
460429
return self._trainable
@@ -583,7 +552,7 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
583552
return dict(
584553
storage_path=self._run_config.storage_path,
585554
storage_filesystem=self._run_config.storage_filesystem,
586-
name=self._experiment_dir_name,
555+
name=self._run_config.name,
587556
mode=self._tune_config.mode,
588557
metric=self._tune_config.metric,
589558
callbacks=self._run_config.callbacks,

python/ray/tune/tests/test_tuner_restore.py

Lines changed: 5 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -713,7 +713,7 @@ def create_trainable_with_params():
713713
)
714714
return trainable_with_params
715715

716-
exp_name = "restore_with_params"
716+
exp_name = f"restore_with_params-{use_function_trainable=}"
717717
fail_marker = tmp_path / "fail_marker"
718718
fail_marker.write_text("", encoding="utf-8")
719719

@@ -943,11 +943,13 @@ def get_checkpoints(experiment_dir):
943943
else:
944944
raise ValueError(f"Invalid trainable type: {trainable_type}")
945945

946+
exp_name = f"{trainable_type=}"
947+
946948
tuner = Tuner(
947949
trainable,
948950
tune_config=TuneConfig(num_samples=1),
949951
run_config=RunConfig(
950-
name="exp_name",
952+
name=exp_name,
951953
storage_path=str(tmp_path),
952954
checkpoint_config=checkpoint_config,
953955
),
@@ -966,7 +968,7 @@ def get_checkpoints(experiment_dir):
966968

967969
fail_marker.unlink()
968970
tuner = Tuner.restore(
969-
str(tmp_path / "exp_name"), trainable=trainable, resume_errored=True
971+
str(tmp_path / exp_name), trainable=trainable, resume_errored=True
970972
)
971973
results = tuner.fit()
972974

python/ray/tune/tune.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -512,9 +512,9 @@ def run(
512512

513513
if _entrypoint == AirEntrypoint.TRAINER:
514514
error_message_map = {
515-
"entrypoint": "Trainer(...)",
515+
"entrypoint": "<FrameworkTrainer>(...)",
516516
"search_space_arg": "param_space",
517-
"restore_entrypoint": 'Trainer.restore(path="{path}", ...)',
517+
"restore_entrypoint": '<FrameworkTrainer>.restore(path="{path}", ...)',
518518
}
519519
elif _entrypoint == AirEntrypoint.TUNER:
520520
error_message_map = {

0 commit comments

Comments (0)