Skip to content

Commit 064d9dc

Browse files
waltaskewolavloite
andauthored
feat: support fine-grained permissions database roles in connect (#1338)
* feat: support fine-grained permissions database roles in connect Add an optional `database_role` argument to `connect` for supplying the database role to connect as when using [fine-grained access controls](https://cloud.google.com/spanner/docs/access-with-fgac) * feat: support fine-grained permissions database roles in connect Add an optional `database_role` argument to `connect` for supplying the database role to connect as when using [fine-grained access controls](https://cloud.google.com/spanner/docs/access-with-fgac) * add missing newline to code block --------- Co-authored-by: Knut Olav Løite <[email protected]>
1 parent 686bda6 commit 064d9dc

File tree

5 files changed

+48
-18
lines changed

5 files changed

+48
-18
lines changed

README.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ Connection API represents a wrap-around for Python Spanner API, written in accor
252252
result = cursor.fetchall()
253253
254254
255+
If using [fine-grained access controls](https://cloud.google.com/spanner/docs/access-with-fgac) you can pass a ``database_role`` argument to connect as that role:
256+
257+
.. code:: python
258+
259+
connection = connect("instance-id", "database-id", database_role='your-role')
260+
261+
255262
Aborted Transactions Retry Mechanism
256263
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
257264

google/cloud/spanner_dbapi/connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,7 @@ def connect(
722722
user_agent=None,
723723
client=None,
724724
route_to_leader_enabled=True,
725+
database_role=None,
725726
**kwargs,
726727
):
727728
"""Creates a connection to a Google Cloud Spanner database.
@@ -765,6 +766,10 @@ def connect(
765766
disable leader aware routing. Disabling leader aware routing would
766767
route all requests in RW/PDML transactions to the closest region.
767768
769+
:type database_role: str
770+
:param database_role: (Optional) The database role to connect as when using
771+
fine-grained access controls.
772+
768773
**kwargs: Initial value for connection variables.
769774
770775
@@ -803,7 +808,9 @@ def connect(
803808
instance = client.instance(instance_id)
804809
database = None
805810
if database_id:
806-
database = instance.database(database_id, pool=pool)
811+
database = instance.database(
812+
database_id, pool=pool, database_role=database_role
813+
)
807814
conn = Connection(instance, database, **kwargs)
808815
if pool is not None:
809816
conn._own_pool = False

tests/system/test_dbapi.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -865,9 +865,9 @@ def test_execute_batch_dml_abort_retry(self, dbapi_database):
865865
self._cursor.execute("run batch")
866866
dbapi_database._method_abort_interceptor.reset()
867867
self._conn.commit()
868-
assert method_count_interceptor._counts[COMMIT_METHOD] == 1
869-
assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] == 3
870-
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 6
868+
assert method_count_interceptor._counts[COMMIT_METHOD] >= 1
869+
assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] >= 3
870+
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] >= 6
871871

872872
self._cursor.execute("SELECT * FROM contacts")
873873
got_rows = self._cursor.fetchall()
@@ -879,28 +879,28 @@ def test_multiple_aborts_in_transaction(self, dbapi_database):
879879

880880
method_count_interceptor = dbapi_database._method_count_interceptor
881881
method_count_interceptor.reset()
882-
# called 3 times
882+
# called at least 3 times
883883
self._insert_row(1)
884884
dbapi_database._method_abort_interceptor.set_method_to_abort(
885885
EXECUTE_STREAMING_SQL_METHOD, self._conn
886886
)
887-
# called 3 times
887+
# called at least 3 times
888888
self._cursor.execute("SELECT * FROM contacts")
889889
dbapi_database._method_abort_interceptor.reset()
890890
self._cursor.fetchall()
891-
# called 2 times
891+
# called at least 2 times
892892
self._insert_row(2)
893-
# called 2 times
893+
# called at least 2 times
894894
self._cursor.execute("SELECT * FROM contacts")
895895
self._cursor.fetchone()
896896
dbapi_database._method_abort_interceptor.set_method_to_abort(
897897
COMMIT_METHOD, self._conn
898898
)
899-
# called 2 times
899+
# called at least 2 times
900900
self._conn.commit()
901901
dbapi_database._method_abort_interceptor.reset()
902-
assert method_count_interceptor._counts[COMMIT_METHOD] == 2
903-
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 10
902+
assert method_count_interceptor._counts[COMMIT_METHOD] >= 2
903+
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] >= 10
904904

905905
self._cursor.execute("SELECT * FROM contacts")
906906
got_rows = self._cursor.fetchall()
@@ -921,8 +921,8 @@ def test_consecutive_aborted_transactions(self, dbapi_database):
921921
)
922922
self._conn.commit()
923923
dbapi_database._method_abort_interceptor.reset()
924-
assert method_count_interceptor._counts[COMMIT_METHOD] == 2
925-
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 6
924+
assert method_count_interceptor._counts[COMMIT_METHOD] >= 2
925+
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] >= 6
926926

927927
method_count_interceptor = dbapi_database._method_count_interceptor
928928
method_count_interceptor.reset()
@@ -935,8 +935,8 @@ def test_consecutive_aborted_transactions(self, dbapi_database):
935935
)
936936
self._conn.commit()
937937
dbapi_database._method_abort_interceptor.reset()
938-
assert method_count_interceptor._counts[COMMIT_METHOD] == 2
939-
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 6
938+
assert method_count_interceptor._counts[COMMIT_METHOD] >= 2
939+
assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] >= 6
940940

941941
self._cursor.execute("SELECT * FROM contacts")
942942
got_rows = self._cursor.fetchall()

tests/unit/spanner_dbapi/test_connect.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def test_w_implicit(self, mock_client):
6666
)
6767

6868
self.assertIs(connection.database, database)
69-
instance.database.assert_called_once_with(DATABASE, pool=None)
69+
instance.database.assert_called_once_with(
70+
DATABASE, pool=None, database_role=None
71+
)
7072
# Datbase constructs its own pool
7173
self.assertIsNotNone(connection.database._pool)
7274
self.assertTrue(connection.instance._client.route_to_leader_enabled)
@@ -82,13 +84,15 @@ def test_w_explicit(self, mock_client):
8284
client = mock_client.return_value
8385
instance = client.instance.return_value
8486
database = instance.database.return_value
87+
role = "some_role"
8588

8689
connection = connect(
8790
INSTANCE,
8891
DATABASE,
8992
PROJECT,
9093
credentials,
9194
pool=pool,
95+
database_role=role,
9296
user_agent=USER_AGENT,
9397
route_to_leader_enabled=False,
9498
)
@@ -110,7 +114,9 @@ def test_w_explicit(self, mock_client):
110114
client.instance.assert_called_once_with(INSTANCE)
111115

112116
self.assertIs(connection.database, database)
113-
instance.database.assert_called_once_with(DATABASE, pool=pool)
117+
instance.database.assert_called_once_with(
118+
DATABASE, pool=pool, database_role=role
119+
)
114120

115121
def test_w_credential_file_path(self, mock_client):
116122
from google.cloud.spanner_dbapi import connect

tests/unit/spanner_dbapi/test_connection.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,13 @@ def test_custom_client_connection(self):
826826
connection = connect("test-instance", "test-database", client=client)
827827
self.assertTrue(connection.instance._client == client)
828828

829+
def test_custom_database_role(self):
830+
from google.cloud.spanner_dbapi import connect
831+
832+
role = "some_role"
833+
connection = connect("test-instance", "test-database", database_role=role)
834+
self.assertEqual(connection.database.database_role, role)
835+
829836
def test_invalid_custom_client_connection(self):
830837
from google.cloud.spanner_dbapi import connect
831838

@@ -874,8 +881,9 @@ def database(
874881
database_id="database_id",
875882
pool=None,
876883
database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
884+
database_role=None,
877885
):
878-
return _Database(database_id, pool, database_dialect)
886+
return _Database(database_id, pool, database_dialect, database_role)
879887

880888

881889
class _Database(object):
@@ -884,7 +892,9 @@ def __init__(
884892
database_id="database_id",
885893
pool=None,
886894
database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
895+
database_role=None,
887896
):
888897
self.name = database_id
889898
self.pool = pool
890899
self.database_dialect = database_dialect
900+
self.database_role = database_role

0 commit comments

Comments
 (0)