[Train] Skip incrementing failure counter on preemption node died failures #41285
Conversation
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
```python
trial.set_location(_Location())

if exception:
    trial.handle_error(exc=exception)
```
Key change 1: The reason for moving this `trial.handle_error` is:

- `handle_error` is what increments the number of failures.
- Upon a trial failure, we used to check `trial.should_recover` BEFORE incrementing the new failure, so `trial.num_failures` is 1 less than it should be at that check. (Let's say `num_failures=2, max_failures=3` at this point.) `handle_error` would happen afterwards right here. (`num_failures=3` now.)
- The old `should_recover` condition made it impossible for us to try recovering on a preemption error. (Even though `handle_error` would noop for the preemption error, we're already at `num_failures == max_failures = 3`, so `should_recover=False`.)
- It's more intuitive to have `num_failures` updated by the time of `trial.should_recover`, so now we just handle the error separately.
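The ordering described above can be sketched with a toy counter (a hypothetical stand-in class, not Tune's actual `Trial`): count the failure first, then check recoverability with `<=`.

```python
# Toy stand-in for the failure accounting described above
# (hypothetical class, not Tune's actual implementation).
class ToyTrial:
    def __init__(self, max_failures):
        self.num_failures = 0
        self.max_failures = max_failures

    def handle_error(self, preempted=False):
        # Preempted errors are no-ops; everything else counts.
        if not preempted:
            self.num_failures += 1

    def should_recover(self):
        # With num_failures updated first, <= is the natural check:
        # we only pass the limit on the (max_failures + 1)-th failure.
        return self.num_failures <= self.max_failures

trial = ToyTrial(max_failures=3)
for _ in range(3):
    trial.handle_error()
assert trial.should_recover()      # 3 <= 3: still allowed to recover
trial.handle_error(preempted=True)
assert trial.should_recover()      # preemption didn't move the counter
trial.handle_error()
assert not trial.should_recover()  # 4 > 3: give up
```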
Seems like we need to call `trial.handle_error` for all the other places that are currently calling `_schedule_trial_stop`?
In general though I think this does move us in the right direction. In the long term we should clean up the state machine in the tune controller; in its current state it's not really clear where error handling is supposed to take place. 😵💫
Yep, I call it in the 2 other places.
```python
def _handle_ray_actor_error(self, exc: RayActorError):
    exc._preempted = True  # TODO(justinvyu): Test the real integration
    if not exc._preempted:
        # Only count non-preempted actor errors as failures.
        self.run_metadata.num_failures += 1

def _handle_ray_task_error(self, exc: RayTaskError):
    if isinstance(exc.cause, RayActorError):
        # Handle the RayActorError directly (ex: Ray Train worker actor errors)
        return self._handle_ray_actor_error(exc.cause)

    # Increment failures for all user errors (which get raised as RayTaskError)
    self.run_metadata.num_failures += 1
```
Key change 2: this is the actual logic of the PR.
Question: Is it ok to treat `RayTaskError` with a cause of `RayActorError` so broadly like this? One strawman counterexample:

```python
def tune_fn_trainable(config):
    e = RayActorError()
    e._preempted = True
    raise e

tune.Tuner(tune_fn_trainable).fit()
```
Another possibility would be to have the DataParallelTrainer pass through the pre-emption RayActorError as a special case, but I feel like that's more misleading, as it's disguising the coordinator's error with the worker's error.
TODO: use `exc.as_instanceof_cause()` instead of the private `cause` attr once that is fixed by @rkooo567
This seems okay for now... it's clean enough that if new use cases come up in the future, we can separate this logic and improve it further.
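The unwrap-and-dispatch rule from the diff above can be illustrated with plain stand-in classes (hypothetical, not the real `ray.exceptions` types):

```python
# Hypothetical stand-ins for ray.exceptions.RayActorError / RayTaskError,
# only to illustrate the counting rule quoted in the diff above.
class RayActorError(Exception):
    def __init__(self, preempted=False):
        self._preempted = preempted

class RayTaskError(Exception):
    def __init__(self, cause):
        self.cause = cause

def counts_as_failure(exc):
    if isinstance(exc, RayTaskError) and isinstance(exc.cause, RayActorError):
        # Unwrap: e.g. a Train worker's actor error surfaces as a task error.
        exc = exc.cause
    if isinstance(exc, RayActorError):
        # Only non-preempted actor errors count.
        return not exc._preempted
    return True  # all user errors count

assert counts_as_failure(RayActorError(preempted=True)) is False
assert counts_as_failure(RayActorError()) is True
assert counts_as_failure(RayTaskError(RayActorError(preempted=True))) is False
assert counts_as_failure(RayTaskError(ValueError("user bug"))) is True
```

The strawman above (a user function deliberately raising a preempted actor error) would indeed slip through this rule, which is the tradeoff being discussed.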
TODO: add the configurability of whether or not to count preemption errors here.
```diff
         `num_failures` should represent the number of times the trial has
         failed *up to the moment this method is called.* If we've failed
         5 times and `max_failures=5`, then we should recover, since
         we only pass the limit on the 6th failure.

         Note this may return true even when there is no checkpoint, either because
         `self.checkpoint_freq` is `0` or because the trial failed before
         a checkpoint has been made.
         """
-        return (
-            self.run_metadata.num_failures < self.max_failures
-            or self.max_failures < 0
-            or (
-                self.run_metadata.num_failures == self.max_failures
-                and self.temporary_state.num_restore_failures
-                < int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
-            )
-        )
+        return self.run_metadata.num_failures <= self.max_failures or self.max_failures < 0
```
See key change 1 comment.
```python
                self.run_metadata.num_failures == self.max_failures
                and self.temporary_state.num_restore_failures
                < int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
            )
```
I believe this condition is not needed anymore.

`TUNE_RESTORE_RETRY_NUM` configures how many attempts we try to restore before it counts as a real error.

The behavior with this condition removed makes sense to me: if I'm at `num_failures == max_failures`, then I should try up to `TUNE_RESTORE_RETRY_NUM` times to restore. If all of those attempts fail, then we'll increment so that `num_failures > max_failures`, and the run will not try to recover anymore.
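The retry accounting described here can be sketched as a small helper (hypothetical function, not the actual Tune implementation):

```python
import os

# Sketch of the restore-retry accounting described above: up to
# TUNE_RESTORE_RETRY_NUM restore attempts are absorbed by a separate
# counter before one real failure is recorded. Hypothetical helper.
def on_restore_failure(num_failures, num_restore_failures):
    retry_num = int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
    if num_restore_failures >= retry_num:
        # All restore retries used up: count it as a real failure.
        return num_failures + 1, num_restore_failures
    return num_failures, num_restore_failures + 1

os.environ["TUNE_RESTORE_RETRY_NUM"] = "2"
failures, restores = 0, 0
failures, restores = on_restore_failure(failures, restores)  # retry 1
failures, restores = on_restore_failure(failures, restores)  # retry 2
failures, restores = on_restore_failure(failures, restores)  # now it counts
assert (failures, restores) == (1, 2)
```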
python/ray/air/tests/test_errors.py
Outdated
```
- Round 0: Actor error in the training worker. (shouldn't be counted)
- Round 1: User error in the training worker.
- Round 2: Actor error in the coordinator actor. (shouldn't be counted)
- Round 3: No error.
```
Should we just run this as 4 separate jobs and check each one if it failed/counted?
I am not able to figure out how to mock a property on the `RayActorError` that core raises -- any ideas here?
I tried this:
```python
class MockRayActorError(ray.exceptions.RayActorError):
    preempted = True

monkeypatch.setattr(
    ray.tune.execution.tune_controller, "RayActorError", MockRayActorError
)
monkeypatch.setattr(ray.exceptions, "RayActorError", MockRayActorError)
```

I was planning on reworking this test to use the actual `gcs_client.drain_node` API to mock the preemption instead of mocking the attribute. (example here)
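For what it's worth, a read-only property can be overridden at the class level rather than the instance level; a sketch with a stand-in class (`monkeypatch.setattr` on the class does the same thing inside a pytest test):

```python
# Stand-in exception class with a read-only property, mimicking the
# shape of RayActorError.preempted (hypothetical, for illustration only).
class FakeActorError(Exception):
    @property
    def preempted(self):
        return False

# Properties live on the class, not the instance, so patch the class:
FakeActorError.preempted = property(lambda self: True)
assert FakeActorError().preempted is True
```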
python/ray/tune/experiment/trial.py
Outdated
```python
def _handle_restore_error(self, exc: _TuneRestoreError):
    exc = exc.exc
    if self.temporary_state.num_restore_failures >= int(
        os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
    ):
        # Restore was unsuccessful, try again without checkpoint.
        self.clear_checkpoint()
        self.run_metadata.num_failures += 1
    else:
        self.temporary_state.num_restore_failures += 1
```
Orthogonal to this change, but I'm wondering if we even want to keep this logic... it's not really clear to me why we remove the checkpoint and increase the number of failures.
This `_handle_restore_error` happens when the call to `Trainable.restore` fails:

- This may be caused by a checkpoint download from cloud failing. Retrying without adding to the total failures counter may help here.
- There may be a bug in a user's `load_checkpoint` code. Retrying wouldn't help here.
- Function trainables don't do any logic in `restore`/`load_checkpoint`, leaving it to the user instead -- so this only really applies to class trainables.
The default behavior is a little strange though: `TUNE_RESTORE_RETRY_NUM=0` --> failures during restore clear the checkpoint, count toward `num_failures`, and the run starts from scratch immediately.

If we remove this logic, the behavior becomes: failures during restore are treated normally and keep retrying from the checkpoint until `max_failures`. I think it makes sense to remove this and `restoring_from` so that we have to keep track of less state in total. Let's do that in a separate PR.
It seems the function here treats restoration errors differently from normal training errors. I.e., there is a separate counter for restoration errors, such that a run of `TUNE_RESTORE_RETRY_NUM` consecutive restoration errors counts as one `num_failures`.

I think it still makes sense to keep this logic here. However, it doesn't make much sense for us to clear or modify the latest checkpoint content. The `clear_checkpoint` function is mainly meant for the case where a corrupted checkpoint leads to a restoration error. I agree that might be one cause of the problem, but it is not generally the only reason for a restoration failure. I think the easiest fix here is to remove the `clear_checkpoint` call. We can still keep the `handle_restore_error` function as a special case of all errors.

I.e., `TUNE_RESTORE_RETRY_NUM` restoration failures contribute one `num_failures`, but we don't pre-assume or add special handling to fix the restoration error. It may well be due to a node preemption, which needs no special handling; just by chance it may fail or succeed, so adding a few more retries can already help. We should not clear the latest checkpoint, which introduces extra complexity. In case of a corrupted latest checkpoint, we just let the job fail after `TUNE_RESTORE_RETRY_NUM * TUNE_RESTORE_NUM`.

cc @justinvyu @matthewdeng, if this looks good, I can make a PR to fix it.
```diff
         if self.local_path:
             self.run_metadata.error_filename = EXPR_ERROR_FILE

-        if isinstance(exc, RayTaskError):
+        if isinstance(exc, (RayTaskError, RayActorError)):
```
Hmmm given that we've never logged these before, when does RayActorError actually happen? Would it be a RayActorError or RayTaskError if the trial node gets preempted?
If the trial node gets preempted, it's a `RayActorError`.

- `ray.get(A.task.remote())` -> `RayActorError` if `A`'s node dies.
- `ray.get(A.task.remote())` -> `RayTaskError(OriginalError)` if `A.task` raises an `OriginalError` inside it.
I think it was just an oversight not to log `RayActorError`.
```python
trial.temporary_state.saving_to = None
if trial.is_restoring and exc:
    exc = _TuneRestoreError(exc)
self._schedule_trial_stop(trial, exception=exc)
```
Do we need to call `trial.handle_error(exception)` before this one?

`try_recover` only gets called in `process_trial_failure`, which already calls `handle_error`.
…d failures (ray-project#41285)

Users expect different failure types to be handled differently in step 4 above:

* The current behavior is that the count decrements, regardless of the error type. For example, if 3 pre-emptions happen with `max_failures=3`, then the run will end without continuing to recover through preemptions.
* With `max_failures=-1` or some large value, there will be an infinite number of retries, but this could crash-loop on an application error (ex: a bug in the user code). This can be very expensive.

This PR changes the failure counting of Ray Train/Tune to ignore spot instance preemption failures by default. This behavior is enabled by the new `RayActorError.preempted` flag introduced in ray-project#41102 that is set if the underlying cluster setup handles the cloud preemption signals properly and sets the preempting node to the `DRAINING` status.
@justinvyu Nit:
…d failures (#41285) (#41609)
@justinvyu did we address @zhe-thoughts's last comment above? (It was made post-merge and I didn't see any additional links in this ticket.)

Yeah, see the updated PR description.
Why are these changes needed?
Users expect different failure types to be handled differently:

* The current behavior is that the count decrements regardless of the error type. For example, if 3 preemptions happen with `max_failures=3`, then the run will end without continuing to recover through preemptions.
* With `max_failures=-1` or some large value, there will be an infinite number of retries, but this could crash-loop on an application error (ex: a bug in the user code). This can be very expensive.

This PR changes the failure counting of Ray Train/Tune to ignore spot instance preemption failures by default. This behavior is enabled by the new `RayActorError.preempted` flag introduced in #41102 that is set if the underlying cluster setup handles the cloud preemption signals properly and sets the preempting node to the `DRAINING` status.

Example
Here is an example scenario:
* […] (e.g. `train.report` / `torch.distributed`), while D is preempted.
* […] `RayActorError(preempted=True)`.
* […] `preempted=True`. This allows preemption failures to be retried repeatedly without contributing to `train.FailureConfig(max_failures=X)`.
* […] the `X` max failures by setting the environment variable: `RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE=1`

Miscellaneous
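The counting rule with the opt-out toggle can be sketched as follows (hypothetical helper, not Ray's implementation; only the `RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE` env var name comes from the PR):

```python
import os

# Sketch of the rule described above: preempted actor errors don't count
# toward max_failures unless the user opts in via the env toggle.
def next_failure_count(num_failures, preempted):
    count_preemptions = (
        os.environ.get("RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE", "0") == "1"
    )
    if preempted and not count_preemptions:
        return num_failures
    return num_failures + 1

os.environ.pop("RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE", None)
assert next_failure_count(2, preempted=True) == 2   # ignored by default
os.environ["RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE"] = "1"
assert next_failure_count(2, preempted=True) == 3   # counted when opted in
```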
This is the current output in `error.txt`. TODO: the numbering should be fixed, and some indication of ignored errors should be added in.

Related issue number
Checks
- I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.