From ddb54e7eceffec960853c73c8b92b2d22713af00 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 1 Apr 2020 16:44:33 -0700 Subject: [PATCH 1/4] feat(auth): Added delete_saml_provider_config() API --- firebase_admin/_auth_client.py | 13 +++++++++++++ firebase_admin/_auth_providers.py | 4 ++++ firebase_admin/auth.py | 16 ++++++++++++++++ tests/test_auth_providers.py | 25 +++++++++++++++++++++---- tests/test_tenant_mgt.py | 12 ++++++++++++ 5 files changed, 66 insertions(+), 4 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 9af0db56f..6de8e9ec4 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -476,6 +476,19 @@ def update_saml_provider_config( x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) + def delete_saml_provider_config(self, provider_id): + """Deletes the SAMLProviderConfig with the given ID. + + Args: + provider_id: Provider ID string. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + self._provider_manager.delete_saml_provider_config(provider_id) + def _check_jwt_revoked(self, verified_claims, exc_type, label): user = self.get_user(verified_claims.get('uid')) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 6cec0f29a..c8378f6ef 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -151,6 +151,10 @@ def update_saml_provider_config( body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) + def delete_saml_provider_config(self, provider_id): + _validate_saml_provider_id(provider_id) + self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + def _make_request(self, method, path, **kwargs): url = '{0}{1}'.format(self.base_url, path) try: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index c422c3ab7..b26ec10b2 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -633,3 +633,19 @@ def update_saml_provider_config( provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) + + +def delete_saml_provider_config(provider_id, app=None): + """Deletes the SAMLProviderConfig with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + client = _get_client(app) + client.delete_saml_provider_config(provider_id) diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 9ef59fbff..945410262 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -33,6 +33,8 @@ } }""" +INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] + @pytest.fixture(scope='module') def user_mgt_app(): @@ -79,10 +81,8 @@ class TestSAMLProviderConfig: } } - @pytest.mark.parametrize('provider_id', [ - None, True, False, 1, 0, list(), tuple(), dict(), '', 'oidc.provider' - ]) - def test_invalid_provider_id(self, user_mgt_app, provider_id): + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_get_invalid_provider_id(self, user_mgt_app, provider_id): with pytest.raises(ValueError) as excinfo: auth.get_saml_provider_config(provider_id, app=user_mgt_app) @@ -239,6 +239,23 @@ def test_update_empty_values(self, user_mgt_app): got = json.loads(req.body.decode()) assert got == {'displayName': None, 'enabled': False} + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.delete_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_delete(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, '{}') + + auth.delete_saml_provider_config('saml.provider', app=user_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + def test_config_not_found(self, user_mgt_app): _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index e08eaf8de..ee6fe8bf0 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -761,6 +761,18 @@ def test_update_saml_provider_config(self, tenant_mgt_app): recorder, url, SAML_PROVIDER_CONFIG_REQUEST, method='PATCH', prefix=PROVIDER_MGT_URL_PREFIX) + def test_delete_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + client.delete_saml_provider_config('saml.provider') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + def test_tenant_not_found(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) From 5ef2d01d3f3176c75781cd0935e20bab875d365e Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 1 Apr 2020 17:25:28 -0700 Subject: [PATCH 2/4] Preliminary list provider config impl --- firebase_admin/_auth_client.py | 4 + firebase_admin/_auth_providers.py | 113 +++++++++++++++++++++ firebase_admin/auth.py | 10 ++ tests/data/list_saml_provider_configs.json | 40 ++++++++ tests/test_auth_providers.py | 41 +++++++- 5 files changed, 206 insertions(+), 2 deletions(-) create mode 100644 tests/data/list_saml_provider_configs.json diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 6de8e9ec4..838590366 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -489,6 +489,10 @@ def delete_saml_provider_config(self, provider_id): """ self._provider_manager.delete_saml_provider_config(provider_id) + def list_saml_provider_configs( + self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + return self._provider_manager.list_saml_provider_configs(page_token, max_results) + def _check_jwt_revoked(self, verified_claims, exc_type, label): user = self.get_user(verified_claims.get('uid')) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index c8378f6ef..f4cdc99b3 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -22,6 +22,9 @@ from firebase_admin import _user_mgt +MAX_LIST_CONFIGS_RESULTS = 100 + + class ProviderConfig: """Parent type for all authentication provider config types.""" @@ -69,6 +72,94 @@ def rp_entity_id(self): return self._data.get('spConfig', {})['spEntityId'] +class ListProviderConfigsPage: + """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. + + Provides methods for traversing the provider configs included in this page, as well as + retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate + through all provider configs in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results, result_cls, config_key): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + self._result_cls = result_cls + self._config_key = config_key + + @property + def provider_configs(self): + """A list of ``AuthProviderConfig`` instances available in this page.""" + return [self._result_cls(config) for config in self._current.get(self._config_key, [])] + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of provider configs, if available. + + Returns: + ListProviderConfigsPage: Next page of provider configs, or None if this is the last + page. + """ + if self.has_next_page: + return ListProviderConfigsPage( + self._download, self.next_page_token, self._max_results, + self._result_cls, self._config_key) + return None + + def iterate_all(self): + """Retrieves an iterator for provider configs. + + Returned iterator will iterate through all the provider configs in the Firebase project + starting from this page. The iterator will never buffer more than one page of configs + in memory at a time. + + Returns: + iterator: An iterator of AuthProviderConfig instances. + """ + return _ProviderConfigIterator(self) + + +class _ProviderConfigIterator: + """An iterator that allows iterating over provider configs, one at a time. + + This implementation loads a page of configs into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self._current_page.provider_configs): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self._current_page.provider_configs): + result = self._current_page.provider_configs[self._index] + self._index += 1 + return result + raise StopIteration + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + class ProviderConfigClient: """Client for managing Auth provider configurations.""" @@ -155,6 +246,28 @@ def delete_saml_provider_config(self, provider_id): _validate_saml_provider_id(provider_id) self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return ListProviderConfigsPage( + self._fetch_saml_provider_configs, page_token, max_results, + result_cls=SAMLProviderConfig, config_key='inboundSamlConfigs') + + def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + """Fetches a page of SAML provider configs""" + if page_token is not None: + if not isinstance(page_token, str) or not page_token: + raise ValueError('Page token must be a non-empty string.') + if not isinstance(max_results, int): + raise ValueError('Max results must be an integer.') + if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: + raise ValueError( + 'Max results must be a positive integer less than ' + '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) + + params = {'pageSize': max_results} + if page_token: + params['pageToken'] = page_token + return self._make_request('get', '/inboundSamlConfigs', params=params) + def _make_request(self, method, path, **kwargs): url = '{0}{1}'.format(self.base_url, path) try: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index b26ec10b2..d0e162151 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -46,6 +46,7 @@ 'InvalidDynamicLinkDomainError', 'InvalidIdTokenError', 'InvalidSessionCookieError', + 'ListProviderConfigsPage', 'ListUsersPage', 'PhoneNumberAlreadyExistsError', 'ProviderConfig', @@ -67,6 +68,7 @@ 'create_saml_provider_config', 'create_session_cookie', 'create_user', + 'delete_saml_provider_config', 'delete_user', 'generate_email_verification_link', 'generate_password_reset_link', @@ -76,6 +78,7 @@ 'get_user_by_email', 'get_user_by_phone_number', 'import_users', + 'list_saml_provider_configs', 'list_users', 'revoke_refresh_tokens', 'set_custom_user_claims', @@ -100,6 +103,7 @@ InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError +ListProviderConfigsPage = _auth_providers.ListProviderConfigsPage ListUsersPage = _user_mgt.ListUsersPage PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError ProviderConfig = _auth_providers.ProviderConfigClient @@ -649,3 +653,9 @@ def delete_saml_provider_config(provider_id, app=None): """ client = _get_client(app) client.delete_saml_provider_config(provider_id) + + +def list_saml_provider_configs( + page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + client = _get_client(app) + return client.list_saml_provider_configs(page_token, max_results) diff --git a/tests/data/list_saml_provider_configs.json b/tests/data/list_saml_provider_configs.json new file mode 100644 index 000000000..7af13599c --- /dev/null +++ b/tests/data/list_saml_provider_configs.json @@ -0,0 +1,40 @@ +{ + "inboundSamlConfigs": [ + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider0", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + }, + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider1", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + } + ] +} \ No newline at end of file diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 945410262..d12329150 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -26,6 +26,7 @@ USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') CONFIG_NOT_FOUND_RESPONSE = """{ "error": { @@ -268,8 +269,33 @@ def test_config_not_found(self, user_mgt_app): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None - def _assert_provider_config(self, provider_config): - assert provider_config.provider_id == 'saml.provider' + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 1001, False]) + def test_invalid_max_results(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 1001, False]) + def test_invalid_page_token(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) + + def test_list_single_page(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + page = auth.list_saml_provider_configs(app=user_mgt_app) + + self._assert_page(page) + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + + def _assert_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id assert provider_config.display_name == 'samlProviderName' assert provider_config.enabled is True assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' @@ -277,3 +303,14 @@ def _assert_provider_config(self, provider_config): assert provider_config.x509_certificates == ['CERT1', 'CERT2'] assert provider_config.rp_entity_id == 'RP_ENTITY_ID' assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + def _assert_page(self, page): + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 From a2e26edde21458379c5df9b35eb2e030d52de805 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 1 Apr 2020 19:09:17 -0700 Subject: [PATCH 3/4] Refactored the common paging logic into base classes --- firebase_admin/_auth_providers.py | 49 ++++++++----------------------- firebase_admin/_auth_utils.py | 36 +++++++++++++++++++++++ firebase_admin/_user_mgt.py | 33 +++------------------ 3 files changed, 53 insertions(+), 65 deletions(-) diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index f4cdc99b3..f3331e8a6 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -80,17 +80,15 @@ class ListProviderConfigsPage: through all provider configs in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results, result_cls, config_key): + def __init__(self, download, page_token, max_results): self._download = download self._max_results = max_results self._current = download(page_token, max_results) - self._result_cls = result_cls - self._config_key = config_key @property def provider_configs(self): """A list of ``AuthProviderConfig`` instances available in this page.""" - return [self._result_cls(config) for config in self._current.get(self._config_key, [])] + raise NotImplementedError @property def next_page_token(self): @@ -110,9 +108,7 @@ def get_next_page(self): page. """ if self.has_next_page: - return ListProviderConfigsPage( - self._download, self.next_page_token, self._max_results, - self._result_cls, self._config_key) + return self.__class__(self._download, self.next_page_token, self._max_results) return None def iterate_all(self): @@ -128,36 +124,18 @@ def iterate_all(self): return _ProviderConfigIterator(self) -class _ProviderConfigIterator: - """An iterator that allows iterating over provider configs, one at a time. +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): - This implementation loads a page of configs into memory, and iterates on them. When the whole - page has been traversed, it loads another page. This class never keeps more than one page - of entries in memory. - """ - - def __init__(self, current_page): - if not current_page: - raise ValueError('Current page must not be None.') - self._current_page = current_page - self._index = 0 + @property + def provider_configs(self): + return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] - def next(self): - if self._index == len(self._current_page.provider_configs): - if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() - self._index = 0 - if self._index < len(self._current_page.provider_configs): - result = self._current_page.provider_configs[self._index] - self._index += 1 - return result - raise StopIteration - def __next__(self): - return self.next() +class _ProviderConfigIterator(_auth_utils.PageIterator): - def __iter__(self): - return self + @property + def items(self): + return self._current_page.provider_configs class ProviderConfigClient: @@ -247,9 +225,8 @@ def delete_saml_provider_config(self, provider_id): self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): - return ListProviderConfigsPage( - self._fetch_saml_provider_configs, page_token, max_results, - result_cls=SAMLProviderConfig, config_key='inboundSamlConfigs') + return _ListSAMLProviderConfigsPage( + self._fetch_saml_provider_configs, page_token, max_results) def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): """Fetches a page of SAML provider configs""" diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index e05793d8f..f1ce97dee 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -30,6 +30,42 @@ VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) +class PageIterator: + """An iterator that allows iterating over a sequence of items, one at a time. + + This implementation loads a page of items into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self.items): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self.items): + result = self.items[self._index] + self._index += 1 + return result + raise StopIteration + + @property + def items(self): + raise NotImplementedError + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + def validate_uid(uid, required=False): if uid is None and not required: return None diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 0f3dc1a94..8b0a81adf 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -639,33 +639,8 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -class _UserIterator: - """An iterator that allows iterating over user accounts, one at a time. +class _UserIterator(_auth_utils.PageIterator): - This implementation loads a page of users into memory, and iterates on them. When the whole - page has been traversed, it loads another page. This class never keeps more than one page - of entries in memory. - """ - - def __init__(self, current_page): - if not current_page: - raise ValueError('Current page must not be None.') - self._current_page = current_page - self._index = 0 - - def next(self): - if self._index == len(self._current_page.users): - if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() - self._index = 0 - if self._index < len(self._current_page.users): - result = self._current_page.users[self._index] - self._index += 1 - return result - raise StopIteration - - def __next__(self): - return self.next() - - def __iter__(self): - return self + @property + def items(self): + return self._current_page.users From 82293f0095d1dcb8ace64cbe7cb820e7fbf73681 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 3 Apr 2020 14:20:44 -0700 Subject: [PATCH 4/4] Added more tests for list API --- firebase_admin/_auth_client.py | 21 ++++ firebase_admin/_auth_providers.py | 6 +- firebase_admin/auth.py | 22 ++++ tests/data/list_saml_provider_configs.json | 2 +- tests/test_auth_providers.py | 112 +++++++++++++++++++-- tests/test_tenant_mgt.py | 33 +++++- 6 files changed, 180 insertions(+), 16 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 838590366..761c1a1f7 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -491,6 +491,27 @@ def delete_saml_provider_config(self, provider_id): def list_saml_provider_configs( self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + """Retrieves a page of SAML provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns None. If there are no SAML configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + + Returns: + ListProviderConfigsPage: A ListProviderConfigsPage instance. + + Raises: + ValueError: If max_results or page_token are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ return self._provider_manager.list_saml_provider_configs(page_token, max_results) def _check_jwt_revoked(self, verified_claims, exc_type, label): diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index f3331e8a6..9bcb7cc4b 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -237,12 +237,12 @@ def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CON raise ValueError('Max results must be an integer.') if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: raise ValueError( - 'Max results must be a positive integer less than ' + 'Max results must be a positive integer less than or equal to ' '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) - params = {'pageSize': max_results} + params = 'pageSize={0}'.format(max_results) if page_token: - params['pageToken'] = page_token + params += '&pageToken={0}'.format(page_token) return self._make_request('get', '/inboundSamlConfigs', params=params) def _make_request(self, method, path, **kwargs): diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index d0e162151..7d11bd58c 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -657,5 +657,27 @@ def delete_saml_provider_config(provider_id, app=None): def list_saml_provider_configs( page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + """Retrieves a page of SAML provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns None. If there are no SAML configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + app: An App instance (optional). + + Returns: + ListProviderConfigsPage: A ListProviderConfigsPage instance. + + Raises: + ValueError: If max_results or page_token are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ client = _get_client(app) return client.list_saml_provider_configs(page_token, max_results) diff --git a/tests/data/list_saml_provider_configs.json b/tests/data/list_saml_provider_configs.json index 7af13599c..b568e1e09 100644 --- a/tests/data/list_saml_provider_configs.json +++ b/tests/data/list_saml_provider_configs.json @@ -37,4 +37,4 @@ "enabled": true } ] -} \ No newline at end of file +} diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index d12329150..f5a66a7c5 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -269,12 +269,12 @@ def test_config_not_found(self, user_mgt_app): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 1001, False]) + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 1001, False]) + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) @@ -284,15 +284,94 @@ def test_list_single_page(self, user_mgt_app): page = auth.list_saml_provider_configs(app=user_mgt_app) self._assert_page(page) - assert page.next_page_token == '' - assert page.has_next_page is False - assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + def test_list_multiple_pages(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = self._create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(max_results=10, app=user_mgt_app) + + self._assert_page(page, next_page_token='token') + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + + self._assert_page(page, count=1, start=2) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + def test_paged_iteration(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = self._create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + iterator = page.iterate_all() + + for index in range(2): + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider{0}'.format(index) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider2' + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + with pytest.raises(StopIteration): + next(iterator) + + def test_list_empty_response(self, user_mgt_app): + response = {'inboundSamlConfigs': []} + _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + assert len(page.provider_configs) == 0 + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 0 + + def test_list_error(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.list_saml_provider_configs(app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + def _assert_provider_config(self, provider_config, want_id='saml.provider'): assert isinstance(provider_config, auth.SAMLProviderConfig) assert provider_config.provider_id == want_id @@ -304,13 +383,26 @@ def _assert_provider_config(self, provider_config, want_id='saml.provider'): assert provider_config.rp_entity_id == 'RP_ENTITY_ID' assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' - def _assert_page(self, page): + def _assert_page(self, page, count=2, start=0, next_page_token=''): assert isinstance(page, auth.ListProviderConfigsPage) - index = 0 - assert len(page.provider_configs) == 2 + index = start + assert len(page.provider_configs) == count for provider_config in page.provider_configs: self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) index += 1 - provider_configs = list(config for config in page.iterate_all()) - assert len(provider_configs) == 2 + if next_page_token: + assert page.next_page_token == next_page_token + assert page.has_next_page is True + else: + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + def _create_list_response(self, sample_response, count=3): + configs = [] + for idx in range(count): + config = dict(sample_response) + config['name'] += str(idx) + configs.append(config) + return configs diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index ee6fe8bf0..7cb8e7bab 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -93,6 +93,8 @@ } } +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') + INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] @@ -773,6 +775,32 @@ def test_delete_saml_provider_config(self, tenant_mgt_app): assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( PROVIDER_MGT_URL_PREFIX) + def test_list_saml_provider_configs(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + + page = client.list_saml_provider_configs() + + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_saml_provider_config( + provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format( + PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/inboundSamlConfigs?pageSize=100') + def test_tenant_not_found(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) @@ -794,8 +822,9 @@ def _assert_request( body = json.loads(req.body.decode()) assert body == want_body - def _assert_saml_provider_config(self, provider_config): - assert provider_config.provider_id == 'saml.provider' + def _assert_saml_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id assert provider_config.display_name == 'samlProviderName' assert provider_config.enabled is True assert provider_config.idp_entity_id == 'IDP_ENTITY_ID'