Skip to content

Commit 7a53221

Browse files
authored
feat(zb-experimental): implement close (#1614)
feat(zb-experimental): implement close
1 parent 786af55 commit 7a53221

File tree

2 files changed

+67
-9
lines changed

2 files changed

+67
-9
lines changed

google/cloud/storage/_experimental/asyncio/async_write_object_stream.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,10 @@ async def open(self) -> None:
145145

146146
async def close(self) -> None:
147147
"""Closes the bidi-gRPC connection."""
148-
raise NotImplementedError(
149-
"close() is not implemented yet in _AsyncWriteObjectStream"
150-
)
148+
if not self._is_stream_open:
149+
raise ValueError("Stream is not open")
150+
await self.socket_like_rpc.close()
151+
self._is_stream_open = False
151152

152153
async def send(
153154
self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest

tests/unit/asyncio/test_async_write_object_stream.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616
from unittest import mock
1717

18+
from unittest.mock import AsyncMock
1819
from google.cloud.storage._experimental.asyncio.async_write_object_stream import (
1920
_AsyncWriteObjectStream,
2021
)
@@ -43,6 +44,27 @@ def mock_client():
4344
return client
4445

4546

47+
async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True):
48+
"""Helper to create an instance of _AsyncWriteObjectStream and open it by default."""
49+
socket_like_rpc = AsyncMock()
50+
mock_cls_async_bidi_rpc.return_value = socket_like_rpc
51+
socket_like_rpc.open = AsyncMock()
52+
socket_like_rpc.close = AsyncMock()
53+
54+
mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
55+
mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
56+
mock_response.resource.generation = GENERATION
57+
mock_response.write_handle = WRITE_HANDLE
58+
socket_like_rpc.recv = AsyncMock(return_value=mock_response)
59+
60+
write_obj_stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
61+
62+
if open:
63+
await write_obj_stream.open()
64+
65+
return write_obj_stream
66+
67+
4668
def test_async_write_object_stream_init(mock_client):
4769
"""Test the constructor of _AsyncWriteObjectStream."""
4870
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
@@ -228,7 +250,6 @@ async def test_open_raises_error_on_missing_generation(
228250
ValueError, match="Failed to obtain object generation after opening the stream"
229251
):
230252
await stream.open()
231-
# assert stream.generation_number is None
232253

233254

234255
@pytest.mark.asyncio
@@ -252,13 +273,49 @@ async def test_open_raises_error_on_missing_write_handle(
252273

253274

254275
@pytest.mark.asyncio
255-
async def test_unimplemented_methods_raise_error(mock_client):
256-
"""Test that unimplemented methods raise NotImplementedError."""
257-
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
276+
@mock.patch(
277+
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
278+
)
279+
async def test_close(mock_cls_async_bidi_rpc, mock_client):
280+
"""Test that close successfully closes the stream."""
281+
# Arrange
282+
write_obj_stream = await instantiate_write_obj_stream(
283+
mock_client, mock_cls_async_bidi_rpc, open=True
284+
)
258285

259-
with pytest.raises(NotImplementedError):
260-
await stream.close()
286+
# Act
287+
await write_obj_stream.close()
261288

289+
# Assert
290+
write_obj_stream.socket_like_rpc.close.assert_called_once()
291+
assert not write_obj_stream.is_stream_open
292+
293+
294+
@pytest.mark.asyncio
295+
@mock.patch(
296+
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
297+
)
298+
async def test_close_without_open_should_raise_error(
299+
mock_cls_async_bidi_rpc, mock_client
300+
):
301+
"""Test that closing a stream that is not open raises a ValueError."""
302+
# Arrange
303+
write_obj_stream = await instantiate_write_obj_stream(
304+
mock_client, mock_cls_async_bidi_rpc, open=False
305+
)
306+
307+
# Act & Assert
308+
with pytest.raises(ValueError, match="Stream is not open"):
309+
await write_obj_stream.close()
310+
311+
312+
@pytest.mark.asyncio
313+
@mock.patch(
314+
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
315+
)
316+
async def test_unimplemented_methods_raise_error(mock_async_bidi_rpc, mock_client):
317+
"""Test that unimplemented methods (send, recv) raise NotImplementedError."""
318+
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
262319
with pytest.raises(NotImplementedError):
263320
await stream.send(_storage_v2.BidiWriteObjectRequest())
264321

0 commit comments

Comments
 (0)