18
18
import time
19
19
from datetime import datetime
20
20
from pathlib import Path
21
+ from threading import RLock
21
22
from typing import (
22
23
Any ,
23
24
ClassVar ,
39
40
BaseModel ,
40
41
ConfigDict ,
41
42
Field ,
43
+ PrivateAttr ,
42
44
ValidationError ,
43
45
field_validator ,
44
46
model_validator ,
272
274
replace_localhost_with_internal_hostname ,
273
275
)
274
276
from zenml .utils .pydantic_utils import before_validator_handler
277
+ from zenml .utils .time_utils import utc_now
275
278
from zenml .zen_server .exceptions import exception_from_response
276
279
from zenml .zen_stores .base_zen_store import BaseZenStore
277
280
@@ -440,6 +443,8 @@ class RestZenStore(BaseZenStore):
440
443
_api_token : Optional [APIToken ] = None
441
444
_session : Optional [requests .Session ] = None
442
445
_server_info : Optional [ServerModel ] = None
446
+ _session_lock : RLock = PrivateAttr (default_factory = RLock )
447
+ _last_authenticated : Optional [datetime ] = None
443
448
444
449
# ====================================
445
450
# ZenML Store interface implementation
@@ -4203,74 +4208,79 @@ def session(self) -> requests.Session:
4203
4208
Returns:
4204
4209
A requests session.
4205
4210
"""
4206
- if self ._session is None :
4207
- # We only need to initialize the session once over the lifetime
4208
- # of the client. We can swap the token out when it expires.
4209
- if self .config .verify_ssl is False :
4210
- urllib3 .disable_warnings (
4211
- urllib3 .exceptions .InsecureRequestWarning
4212
- )
4211
+ with self ._session_lock :
4212
+ if self ._session is None :
4213
+ # We only need to initialize the session once over the lifetime
4214
+ # of the client. We can swap the token out when it expires.
4215
+ if self .config .verify_ssl is False :
4216
+ urllib3 .disable_warnings (
4217
+ urllib3 .exceptions .InsecureRequestWarning
4218
+ )
4213
4219
4214
- self ._session = requests .Session ()
4215
- # Retries are triggered for all HTTP methods (GET, HEAD, POST, PUT,
4216
- # PATCH, OPTIONS and DELETE) on specific HTTP status codes:
4217
- #
4218
- # 408: Request Timeout.
4219
- # 429: Too Many Requests.
4220
- # 502: Bad Gateway.
4221
- # 503: Service Unavailable.
4222
- # 504: Gateway Timeout
4223
- #
4224
- # This also handles connection level errors, if a connection attempt
4225
- # fails due to transient issues like:
4226
- #
4227
- # DNS resolution errors.
4228
- # Connection timeouts.
4229
- # Network disruptions.
4230
- #
4231
- # Additional errors retried:
4232
- #
4233
- # Read Timeouts: If the server does not send a response within
4234
- # the timeout period.
4235
- # Connection Refused: If the server refuses the connection.
4236
- #
4237
- retries = Retry (
4238
- connect = 5 ,
4239
- read = 8 ,
4240
- redirect = 3 ,
4241
- status = 10 ,
4242
- allowed_methods = [
4243
- "HEAD" ,
4244
- "GET" ,
4245
- "POST" ,
4246
- "PUT" ,
4247
- "PATCH" ,
4248
- "DELETE" ,
4249
- "OPTIONS" ,
4250
- ],
4251
- status_forcelist = [
4252
- 408 , # Request Timeout
4253
- 429 , # Too Many Requests
4254
- 502 , # Bad Gateway
4255
- 503 , # Service Unavailable
4256
- 504 , # Gateway Timeout
4257
- ],
4258
- other = 3 ,
4259
- backoff_factor = 1 ,
4260
- )
4261
- self ._session .mount ("https://" , HTTPAdapter (max_retries = retries ))
4262
- self ._session .mount ("http://" , HTTPAdapter (max_retries = retries ))
4263
- self ._session .verify = self .config .verify_ssl
4264
- # Use a custom user agent to identify the ZenML client in the server
4265
- # logs.
4266
- self ._session .headers .update (
4267
- {"User-Agent" : "zenml/" + zenml .__version__ }
4268
- )
4220
+ self ._session = requests .Session ()
4221
+ # Retries are triggered for all HTTP methods (GET, HEAD, POST, PUT,
4222
+ # PATCH, OPTIONS and DELETE) on specific HTTP status codes:
4223
+ #
4224
+ # 408: Request Timeout.
4225
+ # 429: Too Many Requests.
4226
+ # 502: Bad Gateway.
4227
+ # 503: Service Unavailable.
4228
+ # 504: Gateway Timeout
4229
+ #
4230
+ # This also handles connection level errors, if a connection attempt
4231
+ # fails due to transient issues like:
4232
+ #
4233
+ # DNS resolution errors.
4234
+ # Connection timeouts.
4235
+ # Network disruptions.
4236
+ #
4237
+ # Additional errors retried:
4238
+ #
4239
+ # Read Timeouts: If the server does not send a response within
4240
+ # the timeout period.
4241
+ # Connection Refused: If the server refuses the connection.
4242
+ #
4243
+ retries = Retry (
4244
+ connect = 5 ,
4245
+ read = 8 ,
4246
+ redirect = 3 ,
4247
+ status = 10 ,
4248
+ allowed_methods = [
4249
+ "HEAD" ,
4250
+ "GET" ,
4251
+ "POST" ,
4252
+ "PUT" ,
4253
+ "PATCH" ,
4254
+ "DELETE" ,
4255
+ "OPTIONS" ,
4256
+ ],
4257
+ status_forcelist = [
4258
+ 408 , # Request Timeout
4259
+ 429 , # Too Many Requests
4260
+ 502 , # Bad Gateway
4261
+ 503 , # Service Unavailable
4262
+ 504 , # Gateway Timeout
4263
+ ],
4264
+ other = 3 ,
4265
+ backoff_factor = 1 ,
4266
+ )
4267
+ self ._session .mount (
4268
+ "https://" , HTTPAdapter (max_retries = retries )
4269
+ )
4270
+ self ._session .mount (
4271
+ "http://" , HTTPAdapter (max_retries = retries )
4272
+ )
4273
+ self ._session .verify = self .config .verify_ssl
4274
+ # Use a custom user agent to identify the ZenML client in the server
4275
+ # logs.
4276
+ self ._session .headers .update (
4277
+ {"User-Agent" : "zenml/" + zenml .__version__ }
4278
+ )
4269
4279
4270
- # Note that we return an unauthenticated session here. An API token
4271
- # is only fetched and set in the authorization header when and if it is
4272
- # needed.
4273
- return self ._session
4280
+ # Note that we return an unauthenticated session here. An API token
4281
+ # is only fetched and set in the authorization header when and if it is
4282
+ # needed.
4283
+ return self ._session
4274
4284
4275
4285
def authenticate (self , force : bool = False ) -> None :
4276
4286
"""Authenticate or re-authenticate to the ZenML server.
@@ -4325,6 +4335,7 @@ def authenticate(self, force: bool = False) -> None:
4325
4335
{"Authorization" : "Bearer " + new_api_token }
4326
4336
)
4327
4337
logger .debug (f"Authenticated to { self .url } " )
4338
+ self ._last_authenticated = utc_now ()
4328
4339
4329
4340
@staticmethod
4330
4341
def _handle_response (response : requests .Response ) -> Json :
@@ -4391,9 +4402,9 @@ def _request(
4391
4402
CredentialsNotValid: if the request fails due to invalid
4392
4403
client credentials.
4393
4404
"""
4394
- self . session . headers . update (
4395
- { source_context .name : source_context .get ().value }
4396
- )
4405
+ request_headers = {
4406
+ source_context .name : source_context .get ().value ,
4407
+ }
4397
4408
4398
4409
# If the server replies with a credentials validation (401 Unauthorized)
4399
4410
# error, we (re-)authenticate and retry the request here in the
@@ -4414,19 +4425,21 @@ def _request(
4414
4425
while True :
4415
4426
# Add a request ID to the request headers
4416
4427
request_id = str (uuid4 ())[:8 ]
4417
- self . session . headers . update ({ "X-Request-ID" : request_id })
4428
+ request_headers [ "X-Request-ID" ] = request_id
4418
4429
# Add an idempotency key to the request headers to ensure that
4419
4430
# requests are idempotent.
4420
- self . session . headers . update ({ "Idempotency-Key" : str (uuid4 ())} )
4431
+ request_headers [ "Idempotency-Key" ] = str (uuid4 ())
4421
4432
4422
4433
start_time = time .time ()
4423
4434
logger .debug (f"[{ request_id } ] { method } { path } started..." )
4424
4435
status_code = "failed"
4436
+ last_authenticated = self ._last_authenticated
4425
4437
4426
4438
try :
4427
4439
response = self .session .request (
4428
4440
method ,
4429
4441
url ,
4442
+ headers = request_headers ,
4430
4443
params = params if params else {},
4431
4444
verify = self .config .verify_ssl ,
4432
4445
timeout = timeout or self .config .http_timeout ,
@@ -4442,62 +4455,68 @@ def _request(
4442
4455
# authenticated at all.
4443
4456
credentials_store = get_credentials_store ()
4444
4457
4445
- if self ._api_token is None :
4446
- # The last request was not authenticated with an API
4447
- # token at all. We authenticate here and then try the
4448
- # request again, this time with a valid API token in the
4449
- # header.
4450
- logger .debug (
4451
- f"[{ request_id } ] The last request was not "
4452
- f"authenticated: { e } \n "
4453
- "Re-authenticating and retrying..."
4454
- )
4455
- self .authenticate ()
4456
- elif not credentials_store .can_login (self .url ):
4457
- # The request failed either because we're not
4458
- # authenticated or our current credentials are not valid
4459
- # anymore.
4460
- logger .error (
4461
- "The current token is no longer valid, and "
4462
- "it is not possible to generate a new token using the "
4463
- "configured credentials. Please run "
4464
- f"`zenml login { self .url } ` to "
4465
- "re-authenticate to the server or authenticate using "
4466
- "an API key. See "
4467
- "https://docs.zenml.io/deploying-zenml/connecting-to-zenml/connect-with-a-service-account "
4468
- "for more information."
4469
- )
4470
- # Clear the current token from the credentials store to
4471
- # force a new authentication flow next time.
4472
- get_credentials_store ().clear_token (self .url )
4473
- raise e
4474
- elif not re_authenticated :
4475
- # The last request was authenticated with an API token
4476
- # that was rejected by the server. We attempt a
4477
- # re-authentication here and then retry the request.
4478
- logger .debug (
4479
- f"[{ request_id } ] The last request was authenticated "
4480
- "with an API token that was rejected by the server: "
4481
- f"{ e } \n "
4482
- "Re-authenticating and retrying..."
4483
- )
4484
- re_authenticated = True
4485
- self .authenticate (
4486
- # Ignore the current token and force a re-authentication
4487
- force = True
4488
- )
4489
- else :
4490
- # The last request was made after re-authenticating but
4491
- # still failed. Bailing out.
4492
- logger .debug (
4493
- f"[{ request_id } ] The last request failed after "
4494
- "re-authenticating: {e}\n "
4495
- "Bailing out..."
4496
- )
4497
- raise CredentialsNotValid (
4498
- "The current credentials are no longer valid. Please "
4499
- "log in again using 'zenml login'."
4500
- ) from e
4458
+ with self ._session_lock :
4459
+ if self ._last_authenticated != last_authenticated :
4460
+ # Another thread has re-authenticated since the last
4461
+ # request. We simply retry the request.
4462
+ continue
4463
+
4464
+ if self ._api_token is None :
4465
+ # The last request was not authenticated with an API
4466
+ # token at all. We authenticate here and then try the
4467
+ # request again, this time with a valid API token in the
4468
+ # header.
4469
+ logger .debug (
4470
+ f"[{ request_id } ] The last request was not "
4471
+ f"authenticated: { e } \n "
4472
+ "Re-authenticating and retrying..."
4473
+ )
4474
+ self .authenticate ()
4475
+ elif not credentials_store .can_login (self .url ):
4476
+ # The request failed either because we're not
4477
+ # authenticated or our current credentials are not valid
4478
+ # anymore.
4479
+ logger .error (
4480
+ "The current token is no longer valid, and "
4481
+ "it is not possible to generate a new token using the "
4482
+ "configured credentials. Please run "
4483
+ f"`zenml login { self .url } ` to "
4484
+ "re-authenticate to the server or authenticate using "
4485
+ "an API key. See "
4486
+ "https://docs.zenml.io/deploying-zenml/connecting-to-zenml/connect-with-a-service-account "
4487
+ "for more information."
4488
+ )
4489
+ # Clear the current token from the credentials store to
4490
+ # force a new authentication flow next time.
4491
+ get_credentials_store ().clear_token (self .url )
4492
+ raise e
4493
+ elif not re_authenticated :
4494
+ # The last request was authenticated with an API token
4495
+ # that was rejected by the server. We attempt a
4496
+ # re-authentication here and then retry the request.
4497
+ logger .debug (
4498
+ f"[{ request_id } ] The last request was authenticated "
4499
+ "with an API token that was rejected by the server: "
4500
+ f"{ e } \n "
4501
+ "Re-authenticating and retrying..."
4502
+ )
4503
+ re_authenticated = True
4504
+ self .authenticate (
4505
+ # Ignore the current token and force a re-authentication
4506
+ force = True
4507
+ )
4508
+ else :
4509
+ # The last request was made after re-authenticating but
4510
+ # still failed. Bailing out.
4511
+ logger .debug (
4512
+ f"[{ request_id } ] The last request failed after "
4513
+ "re-authenticating: {e}\n "
4514
+ "Bailing out..."
4515
+ )
4516
+ raise CredentialsNotValid (
4517
+ "The current credentials are no longer valid. Please "
4518
+ "log in again using 'zenml login'."
4519
+ ) from e
4501
4520
finally :
4502
4521
end_time = time .time ()
4503
4522
duration = (end_time - start_time ) * 1000
0 commit comments