1515import pytest
1616from unittest import mock
1717
18+ from unittest .mock import AsyncMock
1819from google .cloud .storage ._experimental .asyncio .async_write_object_stream import (
1920 _AsyncWriteObjectStream ,
2021)
@@ -43,6 +44,27 @@ def mock_client():
4344 return client
4445
4546
47+ async def instantiate_write_obj_stream (mock_client , mock_cls_async_bidi_rpc , open = True ):
48+ """Helper to create an instance of _AsyncWriteObjectStream and open it by default."""
49+ socket_like_rpc = AsyncMock ()
50+ mock_cls_async_bidi_rpc .return_value = socket_like_rpc
51+ socket_like_rpc .open = AsyncMock ()
52+ socket_like_rpc .close = AsyncMock ()
53+
54+ mock_response = mock .MagicMock (spec = _storage_v2 .BidiWriteObjectResponse )
55+ mock_response .resource = mock .MagicMock (spec = _storage_v2 .Object )
56+ mock_response .resource .generation = GENERATION
57+ mock_response .write_handle = WRITE_HANDLE
58+ socket_like_rpc .recv = AsyncMock (return_value = mock_response )
59+
60+ write_obj_stream = _AsyncWriteObjectStream (mock_client , BUCKET , OBJECT )
61+
62+ if open :
63+ await write_obj_stream .open ()
64+
65+ return write_obj_stream
66+
67+
4668def test_async_write_object_stream_init (mock_client ):
4769 """Test the constructor of _AsyncWriteObjectStream."""
4870 stream = _AsyncWriteObjectStream (mock_client , BUCKET , OBJECT )
@@ -228,7 +250,6 @@ async def test_open_raises_error_on_missing_generation(
228250 ValueError , match = "Failed to obtain object generation after opening the stream"
229251 ):
230252 await stream .open ()
231- # assert stream.generation_number is None
232253
233254
234255@pytest .mark .asyncio
@@ -252,13 +273,49 @@ async def test_open_raises_error_on_missing_write_handle(
252273
253274
254275@pytest .mark .asyncio
255- async def test_unimplemented_methods_raise_error (mock_client ):
256- """Test that unimplemented methods raise NotImplementedError."""
257- stream = _AsyncWriteObjectStream (mock_client , BUCKET , OBJECT )
276+ @mock .patch (
277+ "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
278+ )
279+ async def test_close (mock_cls_async_bidi_rpc , mock_client ):
280+ """Test that close successfully closes the stream."""
281+ # Arrange
282+ write_obj_stream = await instantiate_write_obj_stream (
283+ mock_client , mock_cls_async_bidi_rpc , open = True
284+ )
258285
259- with pytest . raises ( NotImplementedError ):
260- await stream .close ()
286+ # Act
287+ await write_obj_stream .close ()
261288
289+ # Assert
290+ write_obj_stream .socket_like_rpc .close .assert_called_once ()
291+ assert not write_obj_stream .is_stream_open
292+
293+
294+ @pytest .mark .asyncio
295+ @mock .patch (
296+ "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
297+ )
298+ async def test_close_without_open_should_raise_error (
299+ mock_cls_async_bidi_rpc , mock_client
300+ ):
301+ """Test that closing a stream that is not open raises a ValueError."""
302+ # Arrange
303+ write_obj_stream = await instantiate_write_obj_stream (
304+ mock_client , mock_cls_async_bidi_rpc , open = False
305+ )
306+
307+ # Act & Assert
308+ with pytest .raises (ValueError , match = "Stream is not open" ):
309+ await write_obj_stream .close ()
310+
311+
312+ @pytest .mark .asyncio
313+ @mock .patch (
314+ "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
315+ )
316+ async def test_unimplemented_methods_raise_error (mock_async_bidi_rpc , mock_client ):
317+ """Test that unimplemented methods (send, recv) raise NotImplementedError."""
318+ stream = _AsyncWriteObjectStream (mock_client , BUCKET , OBJECT )
262319 with pytest .raises (NotImplementedError ):
263320 await stream .send (_storage_v2 .BidiWriteObjectRequest ())
264321
0 commit comments