Skip to content

Commit ad0b3cc

Browse files
authored
Threadsafe RestZenStore (#3758)
* Threadsafe RestZenStore
* Indentation
* Switch to timestamp
1 parent 67e8af8 commit ad0b3cc

File tree

1 file changed

+146
-127
lines changed

1 file changed

+146
-127
lines changed

src/zenml/zen_stores/rest_zen_store.py

Lines changed: 146 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import time
1919
from datetime import datetime
2020
from pathlib import Path
21+
from threading import RLock
2122
from typing import (
2223
Any,
2324
ClassVar,
@@ -39,6 +40,7 @@
3940
BaseModel,
4041
ConfigDict,
4142
Field,
43+
PrivateAttr,
4244
ValidationError,
4345
field_validator,
4446
model_validator,
@@ -272,6 +274,7 @@
272274
replace_localhost_with_internal_hostname,
273275
)
274276
from zenml.utils.pydantic_utils import before_validator_handler
277+
from zenml.utils.time_utils import utc_now
275278
from zenml.zen_server.exceptions import exception_from_response
276279
from zenml.zen_stores.base_zen_store import BaseZenStore
277280

@@ -440,6 +443,8 @@ class RestZenStore(BaseZenStore):
440443
_api_token: Optional[APIToken] = None
441444
_session: Optional[requests.Session] = None
442445
_server_info: Optional[ServerModel] = None
446+
_session_lock: RLock = PrivateAttr(default_factory=RLock)
447+
_last_authenticated: Optional[datetime] = None
443448

444449
# ====================================
445450
# ZenML Store interface implementation
@@ -4203,74 +4208,79 @@ def session(self) -> requests.Session:
42034208
Returns:
42044209
A requests session.
42054210
"""
4206-
if self._session is None:
4207-
# We only need to initialize the session once over the lifetime
4208-
# of the client. We can swap the token out when it expires.
4209-
if self.config.verify_ssl is False:
4210-
urllib3.disable_warnings(
4211-
urllib3.exceptions.InsecureRequestWarning
4212-
)
4211+
with self._session_lock:
4212+
if self._session is None:
4213+
# We only need to initialize the session once over the lifetime
4214+
# of the client. We can swap the token out when it expires.
4215+
if self.config.verify_ssl is False:
4216+
urllib3.disable_warnings(
4217+
urllib3.exceptions.InsecureRequestWarning
4218+
)
42134219

4214-
self._session = requests.Session()
4215-
# Retries are triggered for all HTTP methods (GET, HEAD, POST, PUT,
4216-
# PATCH, OPTIONS and DELETE) on specific HTTP status codes:
4217-
#
4218-
# 408: Request Timeout.
4219-
# 429: Too Many Requests.
4220-
# 502: Bad Gateway.
4221-
# 503: Service Unavailable.
4222-
# 504: Gateway Timeout
4223-
#
4224-
# This also handles connection level errors, if a connection attempt
4225-
# fails due to transient issues like:
4226-
#
4227-
# DNS resolution errors.
4228-
# Connection timeouts.
4229-
# Network disruptions.
4230-
#
4231-
# Additional errors retried:
4232-
#
4233-
# Read Timeouts: If the server does not send a response within
4234-
# the timeout period.
4235-
# Connection Refused: If the server refuses the connection.
4236-
#
4237-
retries = Retry(
4238-
connect=5,
4239-
read=8,
4240-
redirect=3,
4241-
status=10,
4242-
allowed_methods=[
4243-
"HEAD",
4244-
"GET",
4245-
"POST",
4246-
"PUT",
4247-
"PATCH",
4248-
"DELETE",
4249-
"OPTIONS",
4250-
],
4251-
status_forcelist=[
4252-
408, # Request Timeout
4253-
429, # Too Many Requests
4254-
502, # Bad Gateway
4255-
503, # Service Unavailable
4256-
504, # Gateway Timeout
4257-
],
4258-
other=3,
4259-
backoff_factor=1,
4260-
)
4261-
self._session.mount("https://", HTTPAdapter(max_retries=retries))
4262-
self._session.mount("http://", HTTPAdapter(max_retries=retries))
4263-
self._session.verify = self.config.verify_ssl
4264-
# Use a custom user agent to identify the ZenML client in the server
4265-
# logs.
4266-
self._session.headers.update(
4267-
{"User-Agent": "zenml/" + zenml.__version__}
4268-
)
4220+
self._session = requests.Session()
4221+
# Retries are triggered for all HTTP methods (GET, HEAD, POST, PUT,
4222+
# PATCH, OPTIONS and DELETE) on specific HTTP status codes:
4223+
#
4224+
# 408: Request Timeout.
4225+
# 429: Too Many Requests.
4226+
# 502: Bad Gateway.
4227+
# 503: Service Unavailable.
4228+
# 504: Gateway Timeout
4229+
#
4230+
# This also handles connection level errors, if a connection attempt
4231+
# fails due to transient issues like:
4232+
#
4233+
# DNS resolution errors.
4234+
# Connection timeouts.
4235+
# Network disruptions.
4236+
#
4237+
# Additional errors retried:
4238+
#
4239+
# Read Timeouts: If the server does not send a response within
4240+
# the timeout period.
4241+
# Connection Refused: If the server refuses the connection.
4242+
#
4243+
retries = Retry(
4244+
connect=5,
4245+
read=8,
4246+
redirect=3,
4247+
status=10,
4248+
allowed_methods=[
4249+
"HEAD",
4250+
"GET",
4251+
"POST",
4252+
"PUT",
4253+
"PATCH",
4254+
"DELETE",
4255+
"OPTIONS",
4256+
],
4257+
status_forcelist=[
4258+
408, # Request Timeout
4259+
429, # Too Many Requests
4260+
502, # Bad Gateway
4261+
503, # Service Unavailable
4262+
504, # Gateway Timeout
4263+
],
4264+
other=3,
4265+
backoff_factor=1,
4266+
)
4267+
self._session.mount(
4268+
"https://", HTTPAdapter(max_retries=retries)
4269+
)
4270+
self._session.mount(
4271+
"http://", HTTPAdapter(max_retries=retries)
4272+
)
4273+
self._session.verify = self.config.verify_ssl
4274+
# Use a custom user agent to identify the ZenML client in the server
4275+
# logs.
4276+
self._session.headers.update(
4277+
{"User-Agent": "zenml/" + zenml.__version__}
4278+
)
42694279

4270-
# Note that we return an unauthenticated session here. An API token
4271-
# is only fetched and set in the authorization header when and if it is
4272-
# needed.
4273-
return self._session
4280+
# Note that we return an unauthenticated session here. An API token
4281+
# is only fetched and set in the authorization header when and if it is
4282+
# needed.
4283+
return self._session
42744284

42754285
def authenticate(self, force: bool = False) -> None:
42764286
"""Authenticate or re-authenticate to the ZenML server.
@@ -4325,6 +4335,7 @@ def authenticate(self, force: bool = False) -> None:
43254335
{"Authorization": "Bearer " + new_api_token}
43264336
)
43274337
logger.debug(f"Authenticated to {self.url}")
4338+
self._last_authenticated = utc_now()
43284339

43294340
@staticmethod
43304341
def _handle_response(response: requests.Response) -> Json:
@@ -4391,9 +4402,9 @@ def _request(
43914402
CredentialsNotValid: if the request fails due to invalid
43924403
client credentials.
43934404
"""
4394-
self.session.headers.update(
4395-
{source_context.name: source_context.get().value}
4396-
)
4405+
request_headers = {
4406+
source_context.name: source_context.get().value,
4407+
}
43974408

43984409
# If the server replies with a credentials validation (401 Unauthorized)
43994410
# error, we (re-)authenticate and retry the request here in the
@@ -4414,19 +4425,21 @@ def _request(
44144425
while True:
44154426
# Add a request ID to the request headers
44164427
request_id = str(uuid4())[:8]
4417-
self.session.headers.update({"X-Request-ID": request_id})
4428+
request_headers["X-Request-ID"] = request_id
44184429
# Add an idempotency key to the request headers to ensure that
44194430
# requests are idempotent.
4420-
self.session.headers.update({"Idempotency-Key": str(uuid4())})
4431+
request_headers["Idempotency-Key"] = str(uuid4())
44214432

44224433
start_time = time.time()
44234434
logger.debug(f"[{request_id}] {method} {path} started...")
44244435
status_code = "failed"
4436+
last_authenticated = self._last_authenticated
44254437

44264438
try:
44274439
response = self.session.request(
44284440
method,
44294441
url,
4442+
headers=request_headers,
44304443
params=params if params else {},
44314444
verify=self.config.verify_ssl,
44324445
timeout=timeout or self.config.http_timeout,
@@ -4442,62 +4455,68 @@ def _request(
44424455
# authenticated at all.
44434456
credentials_store = get_credentials_store()
44444457

4445-
if self._api_token is None:
4446-
# The last request was not authenticated with an API
4447-
# token at all. We authenticate here and then try the
4448-
# request again, this time with a valid API token in the
4449-
# header.
4450-
logger.debug(
4451-
f"[{request_id}] The last request was not "
4452-
f"authenticated: {e}\n"
4453-
"Re-authenticating and retrying..."
4454-
)
4455-
self.authenticate()
4456-
elif not credentials_store.can_login(self.url):
4457-
# The request failed either because we're not
4458-
# authenticated or our current credentials are not valid
4459-
# anymore.
4460-
logger.error(
4461-
"The current token is no longer valid, and "
4462-
"it is not possible to generate a new token using the "
4463-
"configured credentials. Please run "
4464-
f"`zenml login {self.url}` to "
4465-
"re-authenticate to the server or authenticate using "
4466-
"an API key. See "
4467-
"https://docs.zenml.io/deploying-zenml/connecting-to-zenml/connect-with-a-service-account "
4468-
"for more information."
4469-
)
4470-
# Clear the current token from the credentials store to
4471-
# force a new authentication flow next time.
4472-
get_credentials_store().clear_token(self.url)
4473-
raise e
4474-
elif not re_authenticated:
4475-
# The last request was authenticated with an API token
4476-
# that was rejected by the server. We attempt a
4477-
# re-authentication here and then retry the request.
4478-
logger.debug(
4479-
f"[{request_id}] The last request was authenticated "
4480-
"with an API token that was rejected by the server: "
4481-
f"{e}\n"
4482-
"Re-authenticating and retrying..."
4483-
)
4484-
re_authenticated = True
4485-
self.authenticate(
4486-
# Ignore the current token and force a re-authentication
4487-
force=True
4488-
)
4489-
else:
4490-
# The last request was made after re-authenticating but
4491-
# still failed. Bailing out.
4492-
logger.debug(
4493-
f"[{request_id}] The last request failed after "
4494-
"re-authenticating: {e}\n"
4495-
"Bailing out..."
4496-
)
4497-
raise CredentialsNotValid(
4498-
"The current credentials are no longer valid. Please "
4499-
"log in again using 'zenml login'."
4500-
) from e
4458+
with self._session_lock:
4459+
if self._last_authenticated != last_authenticated:
4460+
# Another thread has re-authenticated since the last
4461+
# request. We simply retry the request.
4462+
continue
4463+
4464+
if self._api_token is None:
4465+
# The last request was not authenticated with an API
4466+
# token at all. We authenticate here and then try the
4467+
# request again, this time with a valid API token in the
4468+
# header.
4469+
logger.debug(
4470+
f"[{request_id}] The last request was not "
4471+
f"authenticated: {e}\n"
4472+
"Re-authenticating and retrying..."
4473+
)
4474+
self.authenticate()
4475+
elif not credentials_store.can_login(self.url):
4476+
# The request failed either because we're not
4477+
# authenticated or our current credentials are not valid
4478+
# anymore.
4479+
logger.error(
4480+
"The current token is no longer valid, and "
4481+
"it is not possible to generate a new token using the "
4482+
"configured credentials. Please run "
4483+
f"`zenml login {self.url}` to "
4484+
"re-authenticate to the server or authenticate using "
4485+
"an API key. See "
4486+
"https://docs.zenml.io/deploying-zenml/connecting-to-zenml/connect-with-a-service-account "
4487+
"for more information."
4488+
)
4489+
# Clear the current token from the credentials store to
4490+
# force a new authentication flow next time.
4491+
get_credentials_store().clear_token(self.url)
4492+
raise e
4493+
elif not re_authenticated:
4494+
# The last request was authenticated with an API token
4495+
# that was rejected by the server. We attempt a
4496+
# re-authentication here and then retry the request.
4497+
logger.debug(
4498+
f"[{request_id}] The last request was authenticated "
4499+
"with an API token that was rejected by the server: "
4500+
f"{e}\n"
4501+
"Re-authenticating and retrying..."
4502+
)
4503+
re_authenticated = True
4504+
self.authenticate(
4505+
# Ignore the current token and force a re-authentication
4506+
force=True
4507+
)
4508+
else:
4509+
# The last request was made after re-authenticating but
4510+
# still failed. Bailing out.
4511+
logger.debug(
4512+
f"[{request_id}] The last request failed after "
4513+
"re-authenticating: {e}\n"
4514+
"Bailing out..."
4515+
)
4516+
raise CredentialsNotValid(
4517+
"The current credentials are no longer valid. Please "
4518+
"log in again using 'zenml login'."
4519+
) from e
45014520
finally:
45024521
end_time = time.time()
45034522
duration = (end_time - start_time) * 1000

0 commit comments

Comments
 (0)