Skip to content
Merged
3 changes: 0 additions & 3 deletions ci/lint/pydoclint-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1241,9 +1241,6 @@ python/ray/data/_internal/util.py
DOC402: Function `make_async_gen` has "yield" statements, but the docstring does not have a "Yields" section
DOC404: Function `make_async_gen` yield type(s) in docstring not consistent with the return annotation. Return annotation exists, but docstring "yields" section does not exist or has 0 type(s).
DOC103: Method `RetryingPyFileSystemHandler.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [retryable_errors: List[str]]. Arguments in the docstring but not in the function signature: [context: ].
DOC104: Function `call_with_retry`: Arguments are the same in the docstring and the function signature, but are in a different order.
DOC105: Function `call_with_retry`: Argument names match, but type hints in these args do not match: f, description, match, max_attempts, max_backoff_s
DOC201: Function `call_with_retry` does not have a return section in docstring
DOC104: Function `iterate_with_retry`: Arguments are the same in the docstring and the function signature, but are in a different order.
DOC105: Function `iterate_with_retry`: Argument names match, but type hints in these args do not match: iterable_factory, description, match, max_attempts, max_backoff_s
DOC001: Method `__init__` Potential formatting errors in docstring. Error message: No specification for "Args": ""
Expand Down
82 changes: 82 additions & 0 deletions python/ray/_common/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import functools
import logging
import random
import time
from typing import Any, Callable, List, Optional

logger = logging.getLogger(__name__)


def call_with_retry(
f: Callable,
description: str,
match: Optional[List[str]] = None,
max_attempts: int = 10,
max_backoff_s: int = 32,
*args,
**kwargs,
) -> Any:
"""Retry a function with exponential backoff.

Args:
f: The function to retry.
description: An imperative description of the function being retried. For
example, "open the file".
match: A list of strings to match in the exception message. If ``None``, any
error is retried.
max_attempts: The maximum number of attempts to retry.
max_backoff_s: The maximum number of seconds to backoff.
*args: Arguments to pass to the function.
**kwargs: Keyword arguments to pass to the function.

Returns:
The result of the function.
"""
# TODO: consider inverse match and matching exception type
assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

for i in range(max_attempts):
try:
return f(*args, **kwargs)
except Exception as e:
exception_str = str(e)
is_retryable = match is None or any(
pattern in exception_str for pattern in match
)
if is_retryable and i + 1 < max_attempts:
# Retry with binary exponential backoff with 20% random jitter.
backoff = min(2**i, max_backoff_s) * (random.uniform(0.8, 1.2))
logger.debug(
f"Retrying {i+1} attempts to {description} after {backoff} seconds."
)
time.sleep(backoff)
else:
if is_retryable:
logger.debug(
f"Failed to {description} after {max_attempts} attempts. Raising."
)
else:
logger.debug(
f"Did not find a match for {exception_str}. Raising after {i+1} attempts."
)
raise e from None


def retry(
description: str,
match: Optional[List[str]] = None,
max_attempts: int = 10,
max_backoff_s: int = 32,
) -> Callable:
"""Decorator-based version of call_with_retry."""

def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def inner(*args, **kwargs):
return call_with_retry(
func, description, match, max_attempts, max_backoff_s, *args, **kwargs
)

return inner

return decorator
1 change: 1 addition & 0 deletions python/ray/_common/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ py_test_module_list(
"test_formatters.py",
"test_network_utils.py",
"test_ray_option_utils.py",
"test_retry.py",
"test_signal_semaphore_utils.py",
"test_signature.py",
"test_utils.py",
Expand Down
95 changes: 95 additions & 0 deletions python/ray/_common/tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import sys

import pytest

from ray._common.retry import (
call_with_retry,
retry,
)


def test_call_with_retry_immediate_success_with_args():
def func(a, b):
return [a, b]

assert call_with_retry(func, "func", [], 1, 0, "a", "b") == ["a", "b"]


def test_retry_immediate_success_with_object_args():
class MyClass:
@retry("func", [], 1, 0)
def func(self, a, b):
return [a, b]

assert MyClass().func("a", "b") == ["a", "b"]


@pytest.mark.parametrize("use_decorator", [True, False])
def test_retry_last_attempt_successful_with_appropriate_wait_time(
monkeypatch, use_decorator
):
sleep_total = 0

def sleep(x):
nonlocal sleep_total
sleep_total += x

monkeypatch.setattr("time.sleep", sleep)
monkeypatch.setattr("random.uniform", lambda a, b: 1)

pattern = "have not reached 4th attempt"
call_count = 0

def func():
nonlocal call_count
call_count += 1
if call_count == 4:
return "success"
raise ValueError(pattern)

args = ["func", [pattern], 4, 3]
if use_decorator:
assert retry(*args)(func)() == "success"
else:
assert call_with_retry(func, *args) == "success"
assert sleep_total == 6 # 1 + 2 + 3


@pytest.mark.parametrize("use_decorator", [True, False])
def test_retry_unretryable_error(use_decorator):
call_count = 0

def func():
nonlocal call_count
call_count += 1
raise ValueError("unretryable error")

args = ["func", ["only retryable error"], 10, 0]
with pytest.raises(ValueError, match="unretryable error"):
if use_decorator:
retry(*args)(func)()
else:
call_with_retry(func, *args)
assert call_count == 1


@pytest.mark.parametrize("use_decorator", [True, False])
def test_retry_fail_all_attempts_retry_all_errors(use_decorator):
call_count = 0

def func():
nonlocal call_count
call_count += 1
raise ValueError(str(call_count))

args = ["func", None, 3, 0]
with pytest.raises(ValueError):
if use_decorator:
retry(*args)(func)()
else:
call_with_retry(func, *args)
assert call_count == 3


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))
6 changes: 2 additions & 4 deletions python/ray/data/_internal/datasource/lance_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

import numpy as np

from ray.data._internal.util import (
_check_import,
call_with_retry,
)
from ray._common.retry import call_with_retry
from ray.data._internal.util import _check_import
from ray.data.block import BlockMetadata
from ray.data.context import DataContext
from ray.data.datasource.datasource import Datasource, ReadTask
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/parquet_datasink.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional

from ray._common.retry import call_with_retry
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.planner.plan_write_op import WRITE_UUID_KWARG_NAME
from ray.data._internal.savemode import SaveMode
from ray.data._internal.util import call_with_retry
from ray.data.block import Block, BlockAccessor
from ray.data.datasource.file_based_datasource import _resolve_kwargs
from ray.data.datasource.file_datasink import _FileDatasink
Expand Down
41 changes: 1 addition & 40 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from packaging.version import parse as parse_version

import ray
from ray._common.retry import call_with_retry
from ray._private.arrow_utils import get_pyarrow_version
from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext

Expand Down Expand Up @@ -1415,46 +1416,6 @@ def open_input_file(self, path: str) -> "pyarrow.NativeFile":
)


def call_with_retry(
f: Callable[[], Any],
description: str,
*,
match: Optional[List[str]] = None,
max_attempts: int = 10,
max_backoff_s: int = 32,
) -> Any:
"""Retry a function with exponential backoff.

Args:
f: The function to retry.
match: A list of strings to match in the exception message. If ``None``, any
error is retried.
description: An imperitive description of the function being retried. For
example, "open the file".
max_attempts: The maximum number of attempts to retry.
max_backoff_s: The maximum number of seconds to backoff.
"""
assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

for i in range(max_attempts):
try:
return f()
except Exception as e:
is_retryable = match is None or any(pattern in str(e) for pattern in match)
if is_retryable and i + 1 < max_attempts:
# Retry with binary expoential backoff with random jitter.
backoff = min((2 ** (i + 1)), max_backoff_s) * (random.random())
logger.debug(
f"Retrying {i+1} attempts to {description} after {backoff} seconds."
)
time.sleep(backoff)
else:
logger.debug(
f"Did not find a match for {str(e)}. Raising after {i+1} attempts."
)
raise e from None


def iterate_with_retry(
iterable_factory: Callable[[], Iterable],
description: str,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/datasource/file_datasink.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
from urllib.parse import urlparse

from ray._common.retry import call_with_retry
from ray._private.arrow_utils import add_creatable_buckets_param_if_s3_uri
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
Expand All @@ -11,7 +12,6 @@
from ray.data._internal.util import (
RetryingPyFileSystem,
_is_local_scheme,
call_with_retry,
)
from ray.data.block import Block, BlockAccessor
from ray.data.context import DataContext
Expand Down
7 changes: 7 additions & 0 deletions python/ray/train/v2/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
# The name of the file that is used to store the checkpoint manager snapshot.
CHECKPOINT_MANAGER_SNAPSHOT_FILENAME = "checkpoint_manager_snapshot.json"

AWS_RETRYABLE_TOKENS = (
"AWS Error SLOW_DOWN",
"AWS Error INTERNAL_FAILURE",
"AWS Error SERVICE_UNAVAILABLE",
"AWS Error NETWORK_CONNECTION",
"AWS Error UNKNOWN",
)

# -----------------------------------------------------------------------
# Environment variables used in the controller, workers, and state actor.
Expand Down
25 changes: 16 additions & 9 deletions python/ray/train/v2/_internal/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import ray
from ray._common.retry import retry
from ray.actor import ActorHandle
from ray.data import DataIterator, Dataset
from ray.train._internal import session
from ray.train._internal.session import _TrainingResult
from ray.train.v2._internal.constants import AWS_RETRYABLE_TOKENS
from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
from ray.train.v2._internal.execution.storage import StorageContext, delete_fs_path
from ray.train.v2._internal.util import (
Expand Down Expand Up @@ -215,6 +217,8 @@ def _sync_checkpoint_dir_name_across_ranks(
)
)

# TODO: make retry configurable
@retry(description="upload checkpoint", max_attempts=3, match=AWS_RETRYABLE_TOKENS)
def _upload_checkpoint(
self,
checkpoint_dir_name: str,
Expand Down Expand Up @@ -334,10 +338,10 @@ def report(
# Upload checkpoint, wait for turn, and report.
if checkpoint_upload_mode == CheckpointUploadMode.SYNC:
training_result = self._upload_checkpoint(
checkpoint_dir_name,
metrics,
checkpoint,
delete_local_checkpoint_after_upload,
checkpoint_dir_name=checkpoint_dir_name,
metrics=metrics,
checkpoint=checkpoint,
delete_local_checkpoint_after_upload=delete_local_checkpoint_after_upload,
)
self._wait_then_report(training_result, report_call_index)

Expand All @@ -357,15 +361,18 @@ def _upload_checkpoint_and_report(
) -> None:
try:
training_result = self._upload_checkpoint(
checkpoint_dir_name,
metrics,
checkpoint,
delete_local_checkpoint_after_upload,
checkpoint_dir_name=checkpoint_dir_name,
metrics=metrics,
checkpoint=checkpoint,
delete_local_checkpoint_after_upload=delete_local_checkpoint_after_upload,
)
self._wait_then_report(training_result, report_call_index)
except Exception as e:
# TODO: env var to disable eager raising
logger.exception(
"Async checkpoint upload failed - shutting down workers"
"Checkpoint upload failed in the background thread. Raising eagerly "
"to avoid training in a corrupted state with more potential progress "
"lost due to checkpointing failures."
)
self.execution_context.training_thread_runner.get_exception_queue().put(
construct_user_exception_with_traceback(e)
Expand Down