Skip to content

Commit 026dbe1

Browse files
committed
WIP: try to isolate redis-entraid in tests
1 parent 2b65eff commit 026dbe1

File tree

7 files changed

+165
-285
lines changed

7 files changed

+165
-285
lines changed

.github/actions/run-tests/action.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ runs:
3939
pip install -U setuptools wheel
4040
pip install -r requirements.txt
4141
pip install -r dev_requirements.txt
42+
pip install -e .
4243
if [ "${{inputs.parser-backend}}" == "hiredis" ]; then
4344
pip install "hiredis${{inputs.hiredis-version}}"
4445
echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV

dev_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ uvloop
1616
vulture>=2.3.0
1717
wheel>=0.30.0
1818
numpy>=1.24.0
19-
redis-entraid==0.3.0b1
19+
#redis-entraid==0.3.0b1

tests/conftest.py

Lines changed: 8 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import random
55
import time
66
from datetime import datetime, timezone
7-
from enum import Enum
87
from typing import Callable, TypeVar, Union
98
from unittest import mock
109
from unittest.mock import Mock
@@ -17,7 +16,6 @@
1716
from redis import Sentinel
1817
from redis.auth.idp import IdentityProviderInterface
1918
from redis.auth.token import JWToken
20-
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
2119
from redis.backoff import NoBackoff
2220
from redis.cache import (
2321
CacheConfig,
@@ -30,22 +28,6 @@
3028
from redis.credentials import CredentialProvider
3129
from redis.exceptions import RedisClusterException
3230
from redis.retry import Retry
33-
from redis_entraid.cred_provider import (
34-
DEFAULT_DELAY_IN_MS,
35-
DEFAULT_EXPIRATION_REFRESH_RATIO,
36-
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
37-
DEFAULT_MAX_ATTEMPTS,
38-
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
39-
EntraIdCredentialsProvider,
40-
)
41-
from redis_entraid.identity_provider import (
42-
ManagedIdentityIdType,
43-
ManagedIdentityProviderConfig,
44-
ManagedIdentityType,
45-
ServicePrincipalIdentityProviderConfig,
46-
_create_provider_from_managed_identity,
47-
_create_provider_from_service_principal,
48-
)
4931
from tests.ssl_utils import get_tls_certificates
5032

5133
REDIS_INFO = {}
@@ -61,11 +43,6 @@
6143
_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest]
6244

6345

64-
class AuthType(Enum):
65-
MANAGED_IDENTITY = "managed_identity"
66-
SERVICE_PRINCIPAL = "service_principal"
67-
68-
6946
# Taken from python3.9
7047
class BooleanOptionalAction(argparse.Action):
7148
def __init__(
@@ -623,124 +600,21 @@ def mock_identity_provider() -> IdentityProviderInterface:
623600
return mock_provider
624601

625602

626-
def identity_provider(request) -> IdentityProviderInterface:
627-
if hasattr(request, "param"):
628-
kwargs = request.param.get("idp_kwargs", {})
629-
else:
630-
kwargs = {}
631-
632-
if request.param.get("mock_idp", None) is not None:
633-
return mock_identity_provider()
634-
635-
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
636-
config = get_identity_provider_config(request=request)
637-
638-
if auth_type == "MANAGED_IDENTITY":
639-
return _create_provider_from_managed_identity(config)
640-
641-
return _create_provider_from_service_principal(config)
642-
643-
644-
def get_identity_provider_config(
645-
request,
646-
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
647-
if hasattr(request, "param"):
648-
kwargs = request.param.get("idp_kwargs", {})
649-
else:
650-
kwargs = {}
651-
652-
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
653-
654-
if auth_type == AuthType.MANAGED_IDENTITY:
655-
return _get_managed_identity_provider_config(request)
656-
657-
return _get_service_principal_provider_config(request)
658-
659-
660-
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
661-
resource = os.getenv("AZURE_RESOURCE")
662-
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
663-
664-
if hasattr(request, "param"):
665-
kwargs = request.param.get("idp_kwargs", {})
666-
else:
667-
kwargs = {}
668-
669-
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
670-
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
671-
672-
return ManagedIdentityProviderConfig(
673-
identity_type=identity_type,
674-
resource=resource,
675-
id_type=id_type,
676-
id_value=id_value,
677-
kwargs=kwargs,
678-
)
679-
680-
681-
def _get_service_principal_provider_config(
682-
request,
683-
) -> ServicePrincipalIdentityProviderConfig:
684-
client_id = os.getenv("AZURE_CLIENT_ID")
685-
client_credential = os.getenv("AZURE_CLIENT_SECRET")
686-
tenant_id = os.getenv("AZURE_TENANT_ID")
687-
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
688-
689-
if hasattr(request, "param"):
690-
kwargs = request.param.get("idp_kwargs", {})
691-
token_kwargs = request.param.get("token_kwargs", {})
692-
timeout = request.param.get("timeout", None)
693-
else:
694-
kwargs = {}
695-
token_kwargs = {}
696-
timeout = None
697-
698-
if isinstance(scopes, str):
699-
scopes = scopes.split(",")
700-
701-
return ServicePrincipalIdentityProviderConfig(
702-
client_id=client_id,
703-
client_credential=client_credential,
704-
scopes=scopes,
705-
timeout=timeout,
706-
token_kwargs=token_kwargs,
707-
tenant_id=tenant_id,
708-
app_kwargs=kwargs,
709-
)
710-
711-
712603
def get_credential_provider(request) -> CredentialProvider:
713604
cred_provider_class = request.param.get("cred_provider_class")
714605
cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
715606

716-
if cred_provider_class != EntraIdCredentialsProvider:
607+
if not cred_provider_class:
608+
pytest.skip("No credential provider class specified in the test")
609+
610+
# Since we can't import EntraIdCredentialsProvider in this module,
611+
# we'll just check the class name.
612+
if cred_provider_class.__name__ != "EntraIdCredentialsProvider":
717613
return cred_provider_class(**cred_provider_kwargs)
718614

719-
idp = identity_provider(request)
720-
expiration_refresh_ratio = cred_provider_kwargs.get(
721-
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
722-
)
723-
lower_refresh_bound_millis = cred_provider_kwargs.get(
724-
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
725-
)
726-
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
727-
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
728-
729-
token_mgr_config = TokenManagerConfig(
730-
expiration_refresh_ratio=expiration_refresh_ratio,
731-
lower_refresh_bound_millis=lower_refresh_bound_millis,
732-
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
733-
retry_policy=RetryPolicy(
734-
max_attempts=max_attempts,
735-
delay_in_ms=delay_in_ms,
736-
),
737-
)
615+
from tests.entraid_utils import get_entra_id_credentials_provider
616+
return get_entra_id_credentials_provider(request, cred_provider_kwargs)
738617

739-
return EntraIdCredentialsProvider(
740-
identity_provider=idp,
741-
token_manager_config=token_mgr_config,
742-
initial_delay_in_ms=delay_in_ms,
743-
)
744618

745619

746620
@pytest.fixture()

tests/entraid_utils.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import os
2+
from enum import Enum
3+
from typing import Union
4+
5+
from redis_entraid.cred_provider import (
6+
DEFAULT_DELAY_IN_MS,
7+
DEFAULT_EXPIRATION_REFRESH_RATIO,
8+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
9+
DEFAULT_MAX_ATTEMPTS,
10+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
11+
EntraIdCredentialsProvider,
12+
)
13+
from redis_entraid.identity_provider import (
14+
ManagedIdentityIdType,
15+
ManagedIdentityProviderConfig,
16+
ManagedIdentityType,
17+
ServicePrincipalIdentityProviderConfig,
18+
_create_provider_from_managed_identity,
19+
_create_provider_from_service_principal,
20+
)
21+
22+
from redis.auth.idp import IdentityProviderInterface
23+
from redis.auth.token_manager import TokenManagerConfig, RetryPolicy
24+
from tests.conftest import mock_identity_provider
25+
26+
27+
class AuthType(Enum):
28+
MANAGED_IDENTITY = "managed_identity"
29+
SERVICE_PRINCIPAL = "service_principal"
30+
31+
32+
33+
34+
def identity_provider(request) -> IdentityProviderInterface:
35+
if hasattr(request, "param"):
36+
kwargs = request.param.get("idp_kwargs", {})
37+
else:
38+
kwargs = {}
39+
40+
if request.param.get("mock_idp", None) is not None:
41+
return mock_identity_provider()
42+
43+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
44+
config = get_identity_provider_config(request=request)
45+
46+
if auth_type == "MANAGED_IDENTITY":
47+
return _create_provider_from_managed_identity(config)
48+
49+
return _create_provider_from_service_principal(config)
50+
51+
52+
def get_identity_provider_config(
53+
request,
54+
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
55+
if hasattr(request, "param"):
56+
kwargs = request.param.get("idp_kwargs", {})
57+
else:
58+
kwargs = {}
59+
60+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
61+
62+
if auth_type == AuthType.MANAGED_IDENTITY:
63+
return _get_managed_identity_provider_config(request)
64+
65+
return _get_service_principal_provider_config(request)
66+
67+
68+
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
69+
resource = os.getenv("AZURE_RESOURCE")
70+
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
71+
72+
if hasattr(request, "param"):
73+
kwargs = request.param.get("idp_kwargs", {})
74+
else:
75+
kwargs = {}
76+
77+
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
78+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
79+
80+
return ManagedIdentityProviderConfig(
81+
identity_type=identity_type,
82+
resource=resource,
83+
id_type=id_type,
84+
id_value=id_value,
85+
kwargs=kwargs,
86+
)
87+
88+
89+
def _get_service_principal_provider_config(
90+
request,
91+
) -> ServicePrincipalIdentityProviderConfig:
92+
client_id = os.getenv("AZURE_CLIENT_ID")
93+
client_credential = os.getenv("AZURE_CLIENT_SECRET")
94+
tenant_id = os.getenv("AZURE_TENANT_ID")
95+
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
96+
97+
if hasattr(request, "param"):
98+
kwargs = request.param.get("idp_kwargs", {})
99+
token_kwargs = request.param.get("token_kwargs", {})
100+
timeout = request.param.get("timeout", None)
101+
else:
102+
kwargs = {}
103+
token_kwargs = {}
104+
timeout = None
105+
106+
if isinstance(scopes, str):
107+
scopes = scopes.split(",")
108+
109+
return ServicePrincipalIdentityProviderConfig(
110+
client_id=client_id,
111+
client_credential=client_credential,
112+
scopes=scopes,
113+
timeout=timeout,
114+
token_kwargs=token_kwargs,
115+
tenant_id=tenant_id,
116+
app_kwargs=kwargs,
117+
)
118+
119+
120+
def get_entra_id_credentials_provider(request, cred_provider_kwargs):
121+
idp = identity_provider(request)
122+
expiration_refresh_ratio = cred_provider_kwargs.get(
123+
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
124+
)
125+
lower_refresh_bound_millis = cred_provider_kwargs.get(
126+
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
127+
)
128+
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
129+
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
130+
token_mgr_config = TokenManagerConfig(
131+
expiration_refresh_ratio=expiration_refresh_ratio,
132+
lower_refresh_bound_millis=lower_refresh_bound_millis,
133+
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
134+
retry_policy=RetryPolicy(
135+
max_attempts=max_attempts,
136+
delay_in_ms=delay_in_ms,
137+
),
138+
)
139+
return EntraIdCredentialsProvider(
140+
identity_provider=idp,
141+
token_manager_config=token_mgr_config,
142+
initial_delay_in_ms=delay_in_ms,
143+
)

0 commit comments

Comments
 (0)