Skip to content

Commit aa36b07

Browse files
Sunny Singhharshachinta
andauthored
feat: Batch Write API implementation and samples (#1027)
* feat: Batch Write API implementation and samples * Update sample * review comments * return public class for mutation groups * Update google/cloud/spanner_v1/batch.py Co-authored-by: Sri Harsha CH <[email protected]> * Update google/cloud/spanner_v1/batch.py Co-authored-by: Sri Harsha CH <[email protected]> * review comments * remove doc * feat(spanner): nit sample data refactoring * review comments * fix test --------- Co-authored-by: Sri Harsha CH <[email protected]> Co-authored-by: Sri Harsha CH <[email protected]>
1 parent 7debe71 commit aa36b07

File tree

9 files changed

+584
-0
lines changed

9 files changed

+584
-0
lines changed

google/cloud/spanner_v1/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from .types.result_set import ResultSetStats
3535
from .types.spanner import BatchCreateSessionsRequest
3636
from .types.spanner import BatchCreateSessionsResponse
37+
from .types.spanner import BatchWriteRequest
38+
from .types.spanner import BatchWriteResponse
3739
from .types.spanner import BeginTransactionRequest
3840
from .types.spanner import CommitRequest
3941
from .types.spanner import CreateSessionRequest
@@ -99,6 +101,8 @@
99101
# google.cloud.spanner_v1.types
100102
"BatchCreateSessionsRequest",
101103
"BatchCreateSessionsResponse",
104+
"BatchWriteRequest",
105+
"BatchWriteResponse",
102106
"BeginTransactionRequest",
103107
"CommitRequest",
104108
"CommitResponse",

google/cloud/spanner_v1/batch.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from google.cloud.spanner_v1 import CommitRequest
1919
from google.cloud.spanner_v1 import Mutation
2020
from google.cloud.spanner_v1 import TransactionOptions
21+
from google.cloud.spanner_v1 import BatchWriteRequest
2122

2223
from google.cloud.spanner_v1._helpers import _SessionWrapper
2324
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
@@ -215,6 +216,99 @@ def __exit__(self, exc_type, exc_val, exc_tb):
215216
self.commit()
216217

217218

219+
class MutationGroup(_BatchBase):
220+
"""A container for mutations.
221+
222+
Clients should use :class:`~google.cloud.spanner_v1.MutationGroups` to
223+
obtain instances instead of directly creating instances.
224+
225+
:type session: :class:`~google.cloud.spanner_v1.session.Session`
226+
:param session: The session used to perform the commit.
227+
228+
:type mutations: list
229+
:param mutations: The list into which mutations are to be accumulated.
230+
"""
231+
232+
def __init__(self, session, mutations=[]):
233+
super(MutationGroup, self).__init__(session)
234+
self._mutations = mutations
235+
236+
237+
class MutationGroups(_SessionWrapper):
238+
"""Accumulate mutation groups for transmission during :meth:`batch_write`.
239+
240+
:type session: :class:`~google.cloud.spanner_v1.session.Session`
241+
:param session: the session used to perform the commit
242+
"""
243+
244+
committed = None
245+
246+
def __init__(self, session):
247+
super(MutationGroups, self).__init__(session)
248+
self._mutation_groups = []
249+
250+
def _check_state(self):
251+
"""Checks if the object's state is valid for making API requests.
252+
253+
:raises: :exc:`ValueError` if the object's state is invalid for making
254+
API requests.
255+
"""
256+
if self.committed is not None:
257+
raise ValueError("MutationGroups already committed")
258+
259+
def group(self):
260+
"""Returns a new `MutationGroup` to which mutations can be added."""
261+
mutation_group = BatchWriteRequest.MutationGroup()
262+
self._mutation_groups.append(mutation_group)
263+
return MutationGroup(self._session, mutation_group.mutations)
264+
265+
def batch_write(self, request_options=None):
266+
"""Executes batch_write.
267+
268+
:type request_options:
269+
:class:`google.cloud.spanner_v1.types.RequestOptions`
270+
:param request_options:
271+
(Optional) Common options for this request.
272+
If a dict is provided, it must be of the same form as the protobuf
273+
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
274+
275+
:rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
276+
:returns: a sequence of responses for each batch.
277+
"""
278+
self._check_state()
279+
280+
database = self._session._database
281+
api = database.spanner_api
282+
metadata = _metadata_with_prefix(database.name)
283+
if database._route_to_leader_enabled:
284+
metadata.append(
285+
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
286+
)
287+
trace_attributes = {"num_mutation_groups": len(self._mutation_groups)}
288+
if request_options is None:
289+
request_options = RequestOptions()
290+
elif type(request_options) is dict:
291+
request_options = RequestOptions(request_options)
292+
293+
request = BatchWriteRequest(
294+
session=self._session.name,
295+
mutation_groups=self._mutation_groups,
296+
request_options=request_options,
297+
)
298+
with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
299+
method = functools.partial(
300+
api.batch_write,
301+
request=request,
302+
metadata=metadata,
303+
)
304+
response = _retry(
305+
method,
306+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
307+
)
308+
self.committed = True
309+
return response
310+
311+
218312
def _make_write_pb(table, columns, values):
219313
"""Helper for :meth:`Batch.insert` et al.
220314

google/cloud/spanner_v1/database.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
_metadata_with_leader_aware_routing,
5151
)
5252
from google.cloud.spanner_v1.batch import Batch
53+
from google.cloud.spanner_v1.batch import MutationGroups
5354
from google.cloud.spanner_v1.keyset import KeySet
5455
from google.cloud.spanner_v1.pool import BurstyPool
5556
from google.cloud.spanner_v1.pool import SessionCheckout
@@ -734,6 +735,17 @@ def batch(self, request_options=None):
734735
"""
735736
return BatchCheckout(self, request_options)
736737

738+
def mutation_groups(self):
739+
"""Return an object which wraps a mutation_group.
740+
741+
The wrapper *must* be used as a context manager, with the mutation group
742+
as the value returned by the wrapper.
743+
744+
:rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout`
745+
:returns: new wrapper
746+
"""
747+
return MutationGroupsCheckout(self)
748+
737749
def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
738750
"""Return an object which wraps a batch read / query.
739751
@@ -1040,6 +1052,39 @@ def __exit__(self, exc_type, exc_val, exc_tb):
10401052
self._database._pool.put(self._session)
10411053

10421054

1055+
class MutationGroupsCheckout(object):
1056+
"""Context manager for using mutation groups from a database.
1057+
1058+
Inside the context manager, checks out a session from the database,
1059+
creates mutation groups from it, making the groups available.
1060+
1061+
Caller must *not* use the object to perform API requests outside the scope
1062+
of the context manager.
1063+
1064+
:type database: :class:`~google.cloud.spanner_v1.database.Database`
1065+
:param database: database to use
1066+
"""
1067+
1068+
def __init__(self, database):
1069+
self._database = database
1070+
self._session = None
1071+
1072+
def __enter__(self):
1073+
"""Begin ``with`` block."""
1074+
session = self._session = self._database._pool.get()
1075+
return MutationGroups(session)
1076+
1077+
def __exit__(self, exc_type, exc_val, exc_tb):
1078+
"""End ``with`` block."""
1079+
if isinstance(exc_val, NotFound):
1080+
# If NotFound exception occurs inside the with block
1081+
# then we validate if the session still exists.
1082+
if not self._session.exists():
1083+
self._session = self._database._pool._new_session()
1084+
self._session.create()
1085+
self._database._pool.put(self._session)
1086+
1087+
10431088
class SnapshotCheckout(object):
10441089
"""Context manager for using a snapshot from a database.
10451090

samples/samples/snippets.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,65 @@ def insert_data(instance_id, database_id):
403403
# [END spanner_insert_data]
404404

405405

406+
# [START spanner_batch_write_at_least_once]
407+
def batch_write(instance_id, database_id):
408+
"""Inserts sample data into the given database via BatchWrite API.
409+
410+
The database and table must already exist and can be created using
411+
`create_database`.
412+
"""
413+
from google.rpc.code_pb2 import OK
414+
415+
spanner_client = spanner.Client()
416+
instance = spanner_client.instance(instance_id)
417+
database = instance.database(database_id)
418+
419+
with database.mutation_groups() as groups:
420+
group1 = groups.group()
421+
group1.insert_or_update(
422+
table="Singers",
423+
columns=("SingerId", "FirstName", "LastName"),
424+
values=[
425+
(16, "Scarlet", "Terry"),
426+
],
427+
)
428+
429+
group2 = groups.group()
430+
group2.insert_or_update(
431+
table="Singers",
432+
columns=("SingerId", "FirstName", "LastName"),
433+
values=[
434+
(17, "Marc", ""),
435+
(18, "Catalina", "Smith"),
436+
],
437+
)
438+
group2.insert_or_update(
439+
table="Albums",
440+
columns=("SingerId", "AlbumId", "AlbumTitle"),
441+
values=[
442+
(17, 1, "Total Junk"),
443+
(18, 2, "Go, Go, Go"),
444+
],
445+
)
446+
447+
for response in groups.batch_write():
448+
if response.status.code == OK:
449+
print(
450+
"Mutation group indexes {} have been applied with commit timestamp {}".format(
451+
response.indexes, response.commit_timestamp
452+
)
453+
)
454+
else:
455+
print(
456+
"Mutation group indexes {} could not be applied with error {}".format(
457+
response.indexes, response.status
458+
)
459+
)
460+
461+
462+
# [END spanner_batch_write_at_least_once]
463+
464+
406465
# [START spanner_delete_data]
407466
def delete_data(instance_id, database_id):
408467
"""Deletes sample data from the given database.
@@ -2677,6 +2736,7 @@ def drop_sequence(instance_id, database_id):
26772736
subparsers.add_parser("create_instance", help=create_instance.__doc__)
26782737
subparsers.add_parser("create_database", help=create_database.__doc__)
26792738
subparsers.add_parser("insert_data", help=insert_data.__doc__)
2739+
subparsers.add_parser("batch_write", help=batch_write.__doc__)
26802740
subparsers.add_parser("delete_data", help=delete_data.__doc__)
26812741
subparsers.add_parser("query_data", help=query_data.__doc__)
26822742
subparsers.add_parser("read_data", help=read_data.__doc__)
@@ -2811,6 +2871,8 @@ def drop_sequence(instance_id, database_id):
28112871
create_database(args.instance_id, args.database_id)
28122872
elif args.command == "insert_data":
28132873
insert_data(args.instance_id, args.database_id)
2874+
elif args.command == "batch_write":
2875+
batch_write(args.instance_id, args.database_id)
28142876
elif args.command == "delete_data":
28152877
delete_data(args.instance_id, args.database_id)
28162878
elif args.command == "query_data":

samples/samples/snippets_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,13 @@ def test_insert_data(capsys, instance_id, sample_database):
290290
assert "Inserted data" in out
291291

292292

293+
@pytest.mark.dependency(name="batch_write")
294+
def test_batch_write(capsys, instance_id, sample_database):
295+
snippets.batch_write(instance_id, sample_database.database_id)
296+
out, _ = capsys.readouterr()
297+
assert "could not be applied with error" not in out
298+
299+
293300
@pytest.mark.dependency(depends=["insert_data"])
294301
def test_delete_data(capsys, instance_id, sample_database):
295302
snippets.delete_data(instance_id, sample_database.database_id)

tests/system/_sample_data.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@
2727
(2, "Bharney", "Rhubble", "[email protected]"),
2828
(3, "Wylma", "Phlyntstone", "[email protected]"),
2929
)
30+
BATCH_WRITE_ROW_DATA = (
31+
(1, "Phred", "Phlyntstone", "[email protected]"),
32+
(2, "Bharney", "Rhubble", "[email protected]"),
33+
(3, "Wylma", "Phlyntstone", "[email protected]"),
34+
(4, "Pebbles", "Phlyntstone", "[email protected]"),
35+
(5, "Betty", "Rhubble", "[email protected]"),
36+
(6, "Slate", "Stephenson", "[email protected]"),
37+
)
3038
ALL = spanner_v1.KeySet(all_=True)
3139
SQL = "SELECT * FROM contacts ORDER BY contact_id"
3240

tests/system/test_session_api.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,41 @@ def test_partition_query(sessions_database, not_emulator):
25212521
batch_txn.close()
25222522

25232523

2524+
def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_database):
2525+
sd = _sample_data
2526+
num_groups = 3
2527+
num_mutations_per_group = len(sd.BATCH_WRITE_ROW_DATA) // num_groups
2528+
2529+
with sessions_database.batch() as batch:
2530+
batch.delete(sd.TABLE, sd.ALL)
2531+
2532+
with sessions_database.mutation_groups() as groups:
2533+
for i in range(num_groups):
2534+
group = groups.group()
2535+
for j in range(num_mutations_per_group):
2536+
group.insert_or_update(
2537+
sd.TABLE,
2538+
sd.COLUMNS,
2539+
[sd.BATCH_WRITE_ROW_DATA[i * num_mutations_per_group + j]],
2540+
)
2541+
# Response indexes received
2542+
seen = collections.Counter()
2543+
for response in groups.batch_write():
2544+
_check_batch_status(response.status.code)
2545+
assert response.commit_timestamp is not None
2546+
assert len(response.indexes) > 0
2547+
seen.update(response.indexes)
2548+
# All indexes must be in the range [0, num_groups-1] and seen exactly once
2549+
assert len(seen) == num_groups
2550+
assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items())
2551+
2552+
# Verify the writes by reading from the database
2553+
with sessions_database.snapshot() as snapshot:
2554+
rows = list(snapshot.execute_sql(sd.SQL))
2555+
2556+
sd._check_rows_data(rows, sd.BATCH_WRITE_ROW_DATA)
2557+
2558+
25242559
class FauxCall:
25252560
def __init__(self, code, details="FauxCall"):
25262561
self._code = code

0 commit comments

Comments
 (0)