diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 8f9259e1a..0e738bf86 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -523,12 +523,13 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): email, action_code_settings=action_code_settings) +# TODO: Rename to public type Client class _AuthService: """Firebase Authentication service.""" ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1/projects/' - def __init__(self, app): + def __init__(self, app, tenant_id=None): credential = app.credential.get_credential() version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) @@ -538,12 +539,21 @@ def __init__(self, app): 2. set the project ID explicitly via Firebase App options, or 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") - client = _http_client.JsonHttpClient( - credential=credential, base_url=self.ID_TOOLKIT_URL + app.project_id, + url_path = app.project_id + if tenant_id: + url_path += '/tenants/{0}'.format(tenant_id) + + http_client = _http_client.JsonHttpClient( + credential=credential, base_url=self.ID_TOOLKIT_URL + url_path, headers={'X-Client-Version': version_header}) - self._token_generator = _token_gen.TokenGenerator(app, client) + self._tenant_id = tenant_id + self._token_generator = _token_gen.TokenGenerator(app, http_client) self._token_verifier = _token_gen.TokenVerifier(app) - self._user_manager = _user_mgt.UserManager(client) + self._user_manager = _user_mgt.UserManager(http_client) + + @property + def tenant_id(self): + return self._tenant_id def create_custom_token(self, uid, developer_claims=None): return self._token_generator.create_custom_token(uid, developer_claims) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index a31e15a0b..8d69e1db5 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -18,9 +18,12 @@ Google Cloud Identity Platform (GCIP) instance. """ +import threading + import requests import firebase_admin +from firebase_admin import auth from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _utils @@ -35,6 +38,7 @@ 'Tenant', 'TenantNotFoundError', + 'auth_for_tenant', 'create_tenant', 'delete_tenant', 'get_tenant', @@ -45,6 +49,23 @@ TenantNotFoundError = _auth_utils.TenantNotFoundError +def auth_for_tenant(tenant_id, app=None): + """Gets an Auth Client instance scoped to the given tenant ID. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Returns: + _AuthService: An _AuthService object. + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.auth_for_tenant(tenant_id) + + def get_tenant(tenant_id, app=None): """Gets the tenant corresponding to the given ``tenant_id``. @@ -211,8 +232,25 @@ def __init__(self, app): credential = app.credential.get_credential() version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) + self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) + self.tenant_clients = {} + self.lock = threading.RLock() + + def auth_for_tenant(self, tenant_id): + """Gets an Auth Client instance scoped to the given tenant ID.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + with self.lock: + if tenant_id in self.tenant_clients: + return self.tenant_clients[tenant_id] + + client = auth._AuthService(self.app, tenant_id=tenant_id) # pylint: disable=protected-access + self.tenant_clients[tenant_id] = client + return client def get_tenant(self, tenant_id): """Gets the tenant corresponding to the given ``tenant_id``.""" diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 6388b262c..03ff3f0ab 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -20,6 +20,7 @@ import pytest import firebase_admin +from firebase_admin import auth from firebase_admin import exceptions from firebase_admin import tenant_mgt from tests import testutils @@ -70,9 +71,13 @@ "nextPageToken": "token" }""" +MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') +MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') + INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' TENANT_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' @@ -93,6 +98,15 @@ def _instrument_tenant_mgt(app, status, payload): return service, recorder +def _instrument_user_mgt(client, status, payload): + recorder = [] + user_manager = client._user_manager + user_manager._client.session.mount( + auth._AuthService.ID_TOOLKIT_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + class TestTenant: @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, list(), tuple(), dict()]) @@ -127,9 +141,10 @@ def test_tenant_optional_params(self): class TestGetTenant: @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) - def test_invalid_tenant_id(self, tenant_id): - with pytest.raises(ValueError): - tenant_mgt.delete_tenant(tenant_id) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.get_tenant(tenant_id, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid tenant ID') def test_get_tenant(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) @@ -253,6 +268,11 @@ def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): 'tenant-id', enable_email_link_sign_in=enable, app=tenant_mgt_app) assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + def test_update_tenant_no_args(self, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', app=tenant_mgt_app) + assert str(excinfo.value).startswith('At least one parameter must be specified for update') + def test_update_tenant(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) tenant = tenant_mgt.update_tenant( @@ -317,9 +337,10 @@ def _assert_request(self, recorder, body, mask): class TestDeleteTenant: @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) - def test_invalid_tenant_id(self, tenant_id): - with pytest.raises(ValueError): - tenant_mgt.delete_tenant(tenant_id) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.delete_tenant(tenant_id, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid tenant ID') def test_delete_tenant(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, '{}') @@ -475,6 +496,216 @@ def _assert_request(self, recorder, expected=None): assert request == expected +class TestAuthForTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError): + tenant_mgt.auth_for_tenant(tenant_id, app=tenant_mgt_app) + + def test_client(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + assert client.tenant_id == 'tenant1' + + def test_client_reuse(self, tenant_mgt_app): + client1 = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + client2 = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + client3 = tenant_mgt.auth_for_tenant('tenant2', app=tenant_mgt_app) + assert client1 is client2 + assert client1 is not client3 + + +class TestTenantAwareUserManagement: + + def test_get_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user('testuser') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) + + def test_get_user_by_email(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user_by_email('testuser@example.com') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'email': ['testuser@example.com']}) + + def test_get_user_by_phone_number(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user_by_phone_number('+1234567890') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'phoneNumber': ['+1234567890']}) + + def test_create_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + uid = client._user_manager.create_user() + + assert uid == 'testuser' + self._assert_request(recorder, '/accounts', {}) + + def test_update_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + uid = client._user_manager.update_user('testuser', email='testuser@example.com') + + assert uid == 'testuser' + self._assert_request(recorder, '/accounts:update', { + 'localId': 'testuser', + 'email': 'testuser@example.com', + }) + + def test_delete_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"kind":"deleteresponse"}') + + client.delete_user('testuser') + + self._assert_request(recorder, '/accounts:delete', {'localId': 'testuser'}) + + def test_set_custom_user_claims(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + claims = {'admin': True} + + client.set_custom_user_claims('testuser', claims) + + self._assert_request(recorder, '/accounts:update', { + 'localId': 'testuser', + 'customAttributes': json.dumps(claims), + }) + + def test_revoke_refresh_tokens(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + client.revoke_refresh_tokens('testuser') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/tenants/tenant-id/accounts:update'.format( + USER_MGT_URL_PREFIX) + body = json.loads(req.body.decode()) + assert body['localId'] == 'testuser' + assert 'validSince' in body + + def test_list_users(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_LIST_USERS_RESPONSE) + + page = client.list_users() + + assert isinstance(page, auth.ListUsersPage) + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + users = list(user for user in page.iterate_all()) + assert len(users) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/accounts:batchGet?maxResults=1000'.format( + USER_MGT_URL_PREFIX) + + def test_import_users(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{}') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + + result = client.import_users(users) + + assert isinstance(result, auth.UserImportResult) + assert result.success_count == 2 + assert result.failure_count == 0 + assert result.errors == [] + self._assert_request(recorder, '/accounts:batchCreate', { + 'users': [{'localId': 'user1'}, {'localId': 'user2'}], + }) + + def test_generate_password_reset_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + + link = client.generate_password_reset_link('test@test.com') + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'PASSWORD_RESET', + 'returnOobLink': True, + }) + + def test_generate_email_verification_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + + link = client.generate_email_verification_link('test@test.com') + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'VERIFY_EMAIL', + 'returnOobLink': True, + }) + + def test_generate_sign_in_with_email_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + settings = auth.ActionCodeSettings(url='http://localhost') + + link = client.generate_sign_in_with_email_link('test@test.com', settings) + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'EMAIL_SIGNIN', + 'returnOobLink': True, + 'continueUrl': 'http://localhost', + }) + + def test_tenant_not_found(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + client.get_user('testuser') + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + + def _assert_request(self, recorder, want_url, want_body): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/tenants/tenant-id{1}'.format(USER_MGT_URL_PREFIX, want_url) + body = json.loads(req.body.decode()) + assert body == want_body + + def _assert_tenant(tenant, tenant_id='tenant-id'): assert isinstance(tenant, tenant_mgt.Tenant) assert tenant.tenant_id == tenant_id