Skip to content

Commit faf8b83

Browse files
fix: Make download_ranges compatible with asyncio.create_task(..) (#1591)
fix: Make `download_ranges` compatible with `asyncio.create_task(..)` (#1591) --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 7fee3dd commit faf8b83

File tree

3 files changed

+179
-32
lines changed

3 files changed

+179
-32
lines changed

google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from __future__ import annotations
16+
import asyncio
1617
import google_crc32c
1718
from google.api_core import exceptions
1819
from google_crc32c import Checksum
@@ -29,6 +30,7 @@
2930
from io import BytesIO
3031
from google.cloud import _storage_v2
3132
from google.cloud.storage.exceptions import DataCorruption
33+
from google.cloud.storage._helpers import generate_random_56_bit_integer
3234

3335

3436
_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100
@@ -78,7 +80,7 @@ class AsyncMultiRangeDownloader:
7880
my_buff2 = BytesIO()
7981
my_buff3 = BytesIO()
8082
my_buff4 = any_object_which_provides_BytesIO_like_interface()
81-
results_arr = await mrd.download_ranges(
83+
await mrd.download_ranges(
8284
[
8385
# (start_byte, bytes_to_read, writeable_buffer)
8486
(0, 100, my_buff1),
@@ -88,8 +90,8 @@ class AsyncMultiRangeDownloader:
8890
]
8991
)
9092
91-
for result in results_arr:
92-
print("downloaded bytes", result)
93+
# verify data in buffers...
94+
assert my_buff2.getbuffer().nbytes == 20
9395
9496
9597
"""
@@ -175,6 +177,10 @@ def __init__(
175177
self.read_obj_str: Optional[_AsyncReadObjectStream] = None
176178
self._is_stream_open: bool = False
177179

180+
self._read_id_to_writable_buffer_dict = {}
181+
self._read_id_to_download_ranges_id = {}
182+
self._download_ranges_id_to_pending_read_ids = {}
183+
178184
async def open(self) -> None:
179185
"""Opens the bidi-gRPC connection to read from the object.
180186
@@ -203,8 +209,8 @@ async def open(self) -> None:
203209
return
204210

205211
async def download_ranges(
206-
self, read_ranges: List[Tuple[int, int, BytesIO]]
207-
) -> List[Result]:
212+
self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None
213+
) -> None:
208214
"""Downloads multiple byte ranges from the object into the buffers
209215
provided by user.
210216
@@ -214,9 +220,36 @@ async def download_ranges(
214220
to be provided by the user, and user has to make sure appropriate
215221
memory is available in the application to avoid out-of-memory crash.
216222
217-
:rtype: List[:class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Result`]
218-
:returns: A list of ``Result`` objects, where each object corresponds
219-
to a requested range.
223+
:type lock: asyncio.Lock
224+
:param lock: (Optional) An asyncio lock to synchronize sends and recvs
225+
on the underlying bidi-GRPC stream. This is required when multiple
226+
coroutines are calling this method concurrently.
227+
228+
Example usage with multiple coroutines:
229+
230+
```
231+
lock = asyncio.Lock()
232+
task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
233+
task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
234+
await asyncio.gather(task1, task2)
235+
236+
```
237+
238+
If the user wants to call this method serially from multiple coroutines,
239+
then providing a lock is not necessary.
240+
241+
```
242+
await mrd.download_ranges(ranges1)
243+
await mrd.download_ranges(ranges2)
244+
245+
# ... some other code ...
246+
247+
```
248+
249+
250+
:raises ValueError: if the underlying bidi-GRPC stream is not open.
251+
:raises ValueError: if the length of read_ranges is more than 1000.
252+
:raises DataCorruption: if a checksum mismatch is detected while reading data.
220253
221254
"""
222255

@@ -228,32 +261,43 @@ async def download_ranges(
228261
if not self._is_stream_open:
229262
raise ValueError("Underlying bidi-gRPC stream is not open")
230263

231-
read_id_to_writable_buffer_dict = {}
232-
results = []
264+
if lock is None:
265+
lock = asyncio.Lock()
266+
267+
_func_id = generate_random_56_bit_integer()
268+
read_ids_in_current_func = set()
233269
for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
234270
read_ranges_segment = read_ranges[
235271
i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
236272
]
237273

238274
read_ranges_for_bidi_req = []
239275
for j, read_range in enumerate(read_ranges_segment):
240-
read_id = i + j
241-
read_id_to_writable_buffer_dict[read_id] = read_range[2]
276+
read_id = generate_random_56_bit_integer()
277+
read_ids_in_current_func.add(read_id)
278+
self._read_id_to_download_ranges_id[read_id] = _func_id
279+
self._read_id_to_writable_buffer_dict[read_id] = read_range[2]
242280
bytes_requested = read_range[1]
243-
results.append(Result(bytes_requested))
244281
read_ranges_for_bidi_req.append(
245282
_storage_v2.ReadRange(
246283
read_offset=read_range[0],
247284
read_length=bytes_requested,
248285
read_id=read_id,
249286
)
250287
)
251-
await self.read_obj_str.send(
252-
_storage_v2.BidiReadObjectRequest(read_ranges=read_ranges_for_bidi_req)
253-
)
288+
async with lock:
289+
await self.read_obj_str.send(
290+
_storage_v2.BidiReadObjectRequest(
291+
read_ranges=read_ranges_for_bidi_req
292+
)
293+
)
294+
self._download_ranges_id_to_pending_read_ids[
295+
_func_id
296+
] = read_ids_in_current_func
254297

255-
while len(read_id_to_writable_buffer_dict) > 0:
256-
response = await self.read_obj_str.recv()
298+
while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0:
299+
async with lock:
300+
response = await self.read_obj_str.recv()
257301

258302
if response is None:
259303
raise Exception("None response received, something went wrong.")
@@ -277,16 +321,15 @@ async def download_ranges(
277321
)
278322

279323
read_id = object_data_range.read_range.read_id
280-
buffer = read_id_to_writable_buffer_dict[read_id]
324+
buffer = self._read_id_to_writable_buffer_dict[read_id]
281325
buffer.write(data)
282-
results[read_id].bytes_written += len(data)
283326

284327
if object_data_range.range_end:
285-
del read_id_to_writable_buffer_dict[
286-
object_data_range.read_range.read_id
287-
]
288-
289-
return results
328+
tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id]
329+
self._download_ranges_id_to_pending_read_ids[
330+
tmp_dn_ranges_id
331+
].remove(read_id)
332+
del self._read_id_to_download_ranges_id[read_id]
290333

291334
async def close(self):
292335
"""

google/cloud/storage/_helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from hashlib import md5
2323
import os
2424
import sys
25+
import secrets
2526
from urllib.parse import urlsplit
2627
from urllib.parse import urlunsplit
2728
from uuid import uuid4
@@ -668,3 +669,20 @@ def _get_default_headers(
668669
"content-type": content_type,
669670
"x-upload-content-type": x_upload_content_type or content_type,
670671
}
672+
673+
674+
def generate_random_56_bit_integer():
    """Generate a cryptographically secure random 56-bit integer.

    56 bits are used instead of 64 because a random unsigned 64-bit value
    can exceed the maximum positive value of a signed 64-bit integer
    (2**63 - 1), which causes overflow issues.

    :rtype: int
    :returns: A secure random integer in the range ``[0, 2**56)``.
    """
    # secrets.randbits draws from the OS CSPRNG, like secrets.token_bytes,
    # and returns the integer directly without a bytes -> int conversion.
    return secrets.randbits(56)

tests/unit/asyncio/test_async_multi_range_downloader.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import pytest
1617
from unittest import mock
1718
from unittest.mock import AsyncMock
@@ -107,6 +108,93 @@ async def test_create_mrd(
107108
assert mrd.read_handle == _TEST_READ_HANDLE
108109
assert mrd.is_stream_open
109110

111+
@mock.patch(
    "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer"
)
@mock.patch(
    "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
)
@mock.patch(
    "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
)
@pytest.mark.asyncio
async def test_download_ranges_via_async_gather(
    self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int
):
    """Two ``download_ranges`` tasks sharing one lock each fill their own buffer.

    The stream's ``send`` and ``recv`` are mocked. The test checks both the
    requests written to the stream and the bytes that end up in each
    caller-supplied buffer.
    """
    # Arrange
    data = b"these_are_18_chars"
    crc32c = Checksum(data).digest()
    crc32c_int = int.from_bytes(crc32c, "big")
    crc32c_checksum_for_data_slice = int.from_bytes(
        Checksum(data[10:16]).digest(), "big"
    )

    mock_mrd = await self._make_mock_mrd(
        mock_grpc_client, mock_cls_async_read_object_stream
    )
    # Each download_ranges call draws one _func_id and then one read_id.
    mock_random_int.side_effect = [123, 456, 789, 91011]
    mock_mrd.read_obj_str.send = AsyncMock()
    mock_mrd.read_obj_str.recv = AsyncMock()

    mock_mrd.read_obj_str.recv.side_effect = [
        _storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                _storage_v2.ObjectRangeData(
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data, crc32c=crc32c_int
                    ),
                    range_end=True,
                    read_range=_storage_v2.ReadRange(
                        read_offset=0, read_length=18, read_id=456
                    ),
                )
            ]
        ),
        _storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                _storage_v2.ObjectRangeData(
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data[10:16],
                        crc32c=crc32c_checksum_for_data_slice,
                    ),
                    range_end=True,
                    read_range=_storage_v2.ReadRange(
                        read_offset=10, read_length=6, read_id=91011
                    ),
                )
            ],
        ),
    ]

    # Act
    buffer = BytesIO()
    second_buffer = BytesIO()
    lock = asyncio.Lock()
    task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock))
    task2 = asyncio.create_task(
        mock_mrd.download_ranges([(10, 6, second_buffer)], lock)
    )
    await asyncio.gather(task1, task2)

    # Assert
    # Verify the requests actually sent on the stream. Assigning
    # ``send.side_effect`` here (as before) would check nothing.
    assert mock_mrd.read_obj_str.send.call_count == 2
    mock_mrd.read_obj_str.send.assert_has_calls(
        [
            mock.call(
                _storage_v2.BidiReadObjectRequest(
                    read_ranges=[
                        _storage_v2.ReadRange(
                            read_offset=0, read_length=18, read_id=456
                        )
                    ]
                )
            ),
            mock.call(
                _storage_v2.BidiReadObjectRequest(
                    read_ranges=[
                        _storage_v2.ReadRange(
                            read_offset=10, read_length=6, read_id=91011
                        )
                    ]
                )
            ),
        ],
        # The relative order of the two sends is a scheduling detail.
        any_order=True,
    )
    assert buffer.getvalue() == data
    assert second_buffer.getvalue() == data[10:16]
194+
195+
@mock.patch(
196+
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer"
197+
)
110198
@mock.patch(
111199
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
112200
)
@@ -115,7 +203,7 @@ async def test_create_mrd(
115203
)
116204
@pytest.mark.asyncio
117205
async def test_download_ranges(
118-
self, mock_grpc_client, mock_cls_async_read_object_stream
206+
self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int
119207
):
120208
# Arrange
121209
data = b"these_are_18_chars"
@@ -125,6 +213,7 @@ async def test_download_ranges(
125213
mock_mrd = await self._make_mock_mrd(
126214
mock_grpc_client, mock_cls_async_read_object_stream
127215
)
216+
mock_random_int.side_effect = [123, 456] # for _func_id and read_id
128217
mock_mrd.read_obj_str.send = AsyncMock()
129218
mock_mrd.read_obj_str.recv = AsyncMock()
130219
mock_mrd.read_obj_str.recv.return_value = _storage_v2.BidiReadObjectResponse(
@@ -135,27 +224,24 @@ async def test_download_ranges(
135224
),
136225
range_end=True,
137226
read_range=_storage_v2.ReadRange(
138-
read_offset=0, read_length=18, read_id=0
227+
read_offset=0, read_length=18, read_id=456
139228
),
140229
)
141230
],
142231
)
143232

144233
# Act
145234
buffer = BytesIO()
146-
results = await mock_mrd.download_ranges([(0, 18, buffer)])
235+
await mock_mrd.download_ranges([(0, 18, buffer)])
147236

148237
# Assert
149238
mock_mrd.read_obj_str.send.assert_called_once_with(
150239
_storage_v2.BidiReadObjectRequest(
151240
read_ranges=[
152-
_storage_v2.ReadRange(read_offset=0, read_length=18, read_id=0)
241+
_storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456)
153242
]
154243
)
155244
)
156-
assert len(results) == 1
157-
assert results[0].bytes_requested == 18
158-
assert results[0].bytes_written == 18
159245
assert buffer.getvalue() == data
160246

161247
@mock.patch(

0 commit comments

Comments
 (0)