11import copy
22import io
3- import os
43import math
54import logging
65from pathlib import Path
2726from ray .tune import Experiment , ExperimentAnalysis , ResumeConfig , TuneError
2827from ray .tune .tune import _Config
2928from ray .tune .registry import is_function_trainable
30- from ray .tune .result import _get_defaults_results_dir
3129from ray .tune .result_grid import ResultGrid
3230from ray .tune .trainable import Trainable
3331from ray .tune .tune import run
@@ -102,7 +100,7 @@ def __init__(
102100 )
103101
104102 self ._tune_config = tune_config or TuneConfig ()
105- self ._run_config = run_config or RunConfig ()
103+ self ._run_config = copy . copy ( run_config ) or RunConfig ()
106104 self ._entrypoint = _entrypoint
107105
108106 # Restore from Tuner checkpoint.
@@ -129,23 +127,27 @@ def __init__(
129127 self ._resume_config = None
130128 self ._is_restored = False
131129 self ._tuner_kwargs = copy .deepcopy (_tuner_kwargs ) or {}
132- (
133- self ._local_experiment_dir ,
134- self ._experiment_dir_name ,
135- ) = self .setup_create_experiment_checkpoint_dir (
136- self .converted_trainable , self ._run_config
137- )
138130 self ._experiment_analysis = None
139131
140- # This needs to happen before `tune.run()` is kicked in.
141- # This is because currently tune does not exit gracefully if
142- # run in ray client mode - if crash happens, it just exits immediately
143- # without allowing for checkpointing tuner and trainable.
144- # Thus this has to happen before tune.run() so that we can have something
145- # to restore from.
146- experiment_checkpoint_path = Path (self ._local_experiment_dir , _TUNER_PKL )
147- with open (experiment_checkpoint_path , "wb" ) as fp :
148- pickle .dump (self .__getstate__ (), fp )
132+ self ._run_config .name = (
133+ self ._run_config .name
134+ or StorageContext .get_experiment_dir_name (self .converted_trainable )
135+ )
136+ # The storage context here is only used to access the resolved
137+ # storage fs and experiment path, in order to avoid duplicating that logic.
138+ # This is NOT the storage context object that gets passed to remote workers.
139+ storage = StorageContext (
140+ storage_path = self ._run_config .storage_path ,
141+ experiment_dir_name = self ._run_config .name ,
142+ storage_filesystem = self ._run_config .storage_filesystem ,
143+ )
144+
145+ fs = storage .storage_filesystem
146+ fs .create_dir (storage .experiment_fs_path )
147+ with fs .open_output_stream (
148+ Path (storage .experiment_fs_path , _TUNER_PKL ).as_posix ()
149+ ) as f :
150+ f .write (pickle .dumps (self .__getstate__ ()))
149151
def get_run_config(self) -> RunConfig:
    """Return the run configuration held by this tuner."""
    run_config = self._run_config
    return run_config
@@ -349,20 +351,16 @@ def _restore_from_path_or_uri(
349351 # Ex: s3://bucket/exp_name -> s3://bucket, exp_name
350352 self ._run_config .name = path_or_uri_obj .name
351353 self ._run_config .storage_path = str (path_or_uri_obj .parent )
352-
353- (
354- self ._local_experiment_dir ,
355- self ._experiment_dir_name ,
356- ) = self .setup_create_experiment_checkpoint_dir (
357- self .converted_trainable , self ._run_config
358- )
354+ # Update the storage_filesystem with the one passed in on restoration, if any.
355+ self ._run_config .storage_filesystem = storage_filesystem
359356
360357 # Load the experiment results at the point where it left off.
361358 try :
362359 self ._experiment_analysis = ExperimentAnalysis (
363360 experiment_checkpoint_path = path_or_uri ,
364361 default_metric = self ._tune_config .metric ,
365362 default_mode = self ._tune_config .mode ,
363+ storage_filesystem = storage_filesystem ,
366364 )
367365 except Exception :
368366 self ._experiment_analysis = None
@@ -426,35 +424,6 @@ def _process_scaling_config(self) -> None:
426424 return
427425 self ._param_space ["scaling_config" ] = scaling_config .__dict__ .copy ()
428426
@classmethod
def setup_create_experiment_checkpoint_dir(
    cls, trainable: TrainableType, run_config: Optional[RunConfig]
) -> Tuple[str, str]:
    """Set up and create the local experiment checkpoint directory.

    This is so that the `tuner.pkl` file gets stored in the same directory
    and gets synced with other experiment results.

    Returns:
        Tuple: (experiment_path, experiment_dir_name)
    """
    # TODO(justinvyu): Move this logic into StorageContext somehow
    # Fall back to a generated directory name when none was configured.
    if run_config.name:
        experiment_dir_name = run_config.name
    else:
        experiment_dir_name = StorageContext.get_experiment_dir_name(trainable)

    local_results_dir = Path(_get_defaults_results_dir())
    experiment_path = (local_results_dir / experiment_dir_name).as_posix()

    os.makedirs(experiment_path, exist_ok=True)
    return experiment_path, experiment_dir_name
# Deliberately a method, not a @property: callers need a plain callable.
def get_experiment_checkpoint_dir(self) -> str:
    """Return the local experiment checkpoint directory path.

    TODO(justinvyu): This is used to populate an error message.
    This should point to the storage path experiment dir instead.
    """
    checkpoint_dir = self._local_experiment_dir
    return checkpoint_dir
@property
def trainable(self) -> TrainableTypeOrTrainer:
    """Return the trainable stored on this tuner."""
    stored_trainable = self._trainable
    return stored_trainable
@@ -583,7 +552,7 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
583552 return dict (
584553 storage_path = self ._run_config .storage_path ,
585554 storage_filesystem = self ._run_config .storage_filesystem ,
586- name = self ._experiment_dir_name ,
555+ name = self ._run_config . name ,
587556 mode = self ._tune_config .mode ,
588557 metric = self ._tune_config .metric ,
589558 callbacks = self ._run_config .callbacks ,
0 commit comments