Skip to content
47 changes: 47 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from unittest.mock import patch

import huggingface_hub.utils
import requests
import urllib3
from huggingface_hub import delete_repo
from packaging import version
Expand Down Expand Up @@ -200,6 +201,8 @@
IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False

logger = transformers_logging.get_logger(__name__)


def parse_flag_from_env(key, default=False):
try:
Expand Down Expand Up @@ -2497,11 +2500,55 @@ def wrapper(*args, **kwargs):
return test_func_ref(*args, **kwargs)

except Exception as err:
logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1

return test_func_ref(*args, **kwargs)

return wrapper

return decorator


def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
    """
    To decorate tests that download from the Hub. They can fail due to a
    variety of network issues such as timeouts, connection resets, etc.

    Args:
        max_attempts (`int`, *optional*, defaults to 5):
            The maximum number of attempts to retry the flaky test.
        wait_before_retry (`float`, *optional*, defaults to 2):
            If provided, will wait that number of seconds before retrying the test.
    """

    def decorator(test_func_ref):
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            retry_count = 1

            # Retry only on network-related failures; any other exception
            # propagates immediately so real test bugs are not masked.
            while retry_count < max_attempts:
                try:
                    return test_func_ref(*args, **kwargs)
                # `RequestException` is the base class of every network error
                # raised by `requests` (ConnectionError, Timeout, ReadTimeout,
                # HTTPError, ...), so catching it alone covers the whole family.
                except requests.exceptions.RequestException as err:
                    print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)
                    retry_count += 1

            # Final attempt: let any exception surface to the test runner.
            return test_func_ref(*args, **kwargs)

        return wrapper

    return decorator
Expand Down
11 changes: 11 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
)
from transformers.testing_utils import (
CaptureLogger,
hub_retry,
is_flaky,
require_accelerate,
require_bitsandbytes,
Expand Down Expand Up @@ -214,6 +215,16 @@ class ModelTesterMixin:
_is_composite = False
model_split_percents = [0.5, 0.7, 0.9]

# Note: for all mixins that utilize the Hub in some way, we should ensure that
# they contain the `hub_retry` decorator in case of failures.
def __init_subclass__(cls, **kwargs):
    """Wrap every `test_*` method of each subclass with `hub_retry` so that
    transient Hub/network failures are retried instead of failing the suite."""
    super().__init_subclass__(**kwargs)
    for attr_name in dir(cls):
        if attr_name.startswith("test_"):
            attr = getattr(cls, attr_name)
            if callable(attr):
                # BUG FIX: `hub_retry` is a decorator *factory* — calling
                # `hub_retry(attr)` passed the test function as `max_attempts`
                # and installed the inner `decorator` as the test method, so
                # the test body was never executed. Call the factory first
                # (with default arguments), then apply the decorator.
                setattr(cls, attr_name, hub_retry()(attr))

@property
def all_generative_model_classes(self):
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
Expand Down
Loading