Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ async def _execute_command(
if result.get("error"):
error = result["error"]
retryable_top_level_error = (
isinstance(error.details, dict)
hasattr(error, "details")
and isinstance(error.details, dict)
and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
)
retryable_network_error = isinstance(
Expand Down
4 changes: 3 additions & 1 deletion pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2562,7 +2562,9 @@ async def run(self) -> T:
if not self._retryable:
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = exc.error.has_error_label("RetryableWriteError")
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
Expand Down
3 changes: 2 additions & 1 deletion pymongo/synchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ def _execute_command(
if result.get("error"):
error = result["error"]
retryable_top_level_error = (
isinstance(error.details, dict)
hasattr(error, "details")
and isinstance(error.details, dict)
and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
)
retryable_network_error = isinstance(
Expand Down
4 changes: 3 additions & 1 deletion pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2549,7 +2549,9 @@ def run(self) -> T:
if not self._retryable:
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = exc.error.has_error_label("RetryableWriteError")
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
Expand Down
16 changes: 15 additions & 1 deletion test/asynchronous/pymongo_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
from test.asynchronous import async_client_context

from pymongo import AsyncMongoClient, common
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.monitor import Monitor
from pymongo.asynchronous.pool import Pool
from pymongo.errors import AutoReconnect, NetworkTimeout
from pymongo.errors import AutoReconnect, ClientBulkWriteException, NetworkTimeout
from pymongo.hello import Hello, HelloCompat
from pymongo.operations import _Op
from pymongo.server_description import ServerDescription
from pymongo.write_concern import WriteConcern

_IS_SYNC = False

Expand Down Expand Up @@ -246,7 +249,18 @@ def mock_hello(self, host):

return response, rtt

async def mock_client_bulk_write(self, models):
blk = _AsyncMockClientBulk(self, write_concern=WriteConcern(w=1))
for model in models:
model._add_to_client_bulk(blk)
return await blk.execute(None, _Op.BULK_WRITE)

def _process_periodic_tasks(self):
# Avoid the background thread causing races, e.g. a surprising
# reconnect while we're trying to test a disconnected client.
pass


class _AsyncMockClientBulk(_AsyncClientBulk):
async def write_command(self, bwc, cmd, request_id, msg, to_send_ops, to_send_ns, client):
return {"error": TypeError("mock type error")}
27 changes: 26 additions & 1 deletion test/asynchronous/test_client_bulk_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous import (
AsyncIntegrationTest,
AsyncMockClientTest,
async_client_context,
unittest,
)
from test.asynchronous.pymongo_mocks import AsyncMockClient
from test.utils import (
OvertCommandListener,
async_rs_or_single_client,
Expand Down Expand Up @@ -577,3 +583,22 @@ async def test_timeout_in_multi_batch_bulk_write(self):
if event.command_name == "bulkWrite":
bulk_write_events.append(event)
self.assertEqual(len(bulk_write_events), 2)


class TestClientBulkWriteMock(AsyncMockClientTest):
@async_client_context.require_version_min(8, 0, 0, -24)
async def test_handles_non_pymongo_error(self):
mock_client = await AsyncMockClient.get_async_mock_client(
standalones=[],
members=["a:1", "b:2", "c:3"],
mongoses=[],
host="b:2", # Pass a secondary.
replicaSet="rs",
heartbeatFrequencyMS=500,
)
self.addAsyncCleanup(mock_client.close)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(ClientBulkWriteException) as context:
await mock_client.mock_client_bulk_write(models=models)
self.assertIsInstance(context.exception.error, TypeError)
self.assertFalse(hasattr(context.exception.error, "details"))
16 changes: 15 additions & 1 deletion test/pymongo_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
from test import client_context

from pymongo import MongoClient, common
from pymongo.errors import AutoReconnect, NetworkTimeout
from pymongo.errors import AutoReconnect, ClientBulkWriteException, NetworkTimeout
from pymongo.hello import Hello, HelloCompat
from pymongo.operations import _Op
from pymongo.server_description import ServerDescription
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.monitor import Monitor
from pymongo.synchronous.pool import Pool
from pymongo.write_concern import WriteConcern

_IS_SYNC = True

Expand Down Expand Up @@ -245,7 +248,18 @@ def mock_hello(self, host):

return response, rtt

def mock_client_bulk_write(self, models):
blk = _MockClientBulk(self, write_concern=WriteConcern(w=1))
for model in models:
model._add_to_client_bulk(blk)
return blk.execute(None, _Op.BULK_WRITE)

def _process_periodic_tasks(self):
# Avoid the background thread causing races, e.g. a surprising
# reconnect while we're trying to test a disconnected client.
pass


class _MockClientBulk(_ClientBulk):
def write_command(self, bwc, cmd, request_id, msg, to_send_ops, to_send_ns, client):
return {"error": TypeError("mock type error")}
27 changes: 26 additions & 1 deletion test/test_client_bulk_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@

sys.path[0:0] = [""]

from test import IntegrationTest, client_context, unittest
from test import (
IntegrationTest,
MockClientTest,
client_context,
unittest,
)
from test.pymongo_mocks import MockClient
from test.utils import (
OvertCommandListener,
rs_or_single_client,
Expand Down Expand Up @@ -577,3 +583,22 @@ def test_timeout_in_multi_batch_bulk_write(self):
if event.command_name == "bulkWrite":
bulk_write_events.append(event)
self.assertEqual(len(bulk_write_events), 2)


class TestClientBulkWriteMock(MockClientTest):
@client_context.require_version_min(8, 0, 0, -24)
def test_handles_non_pymongo_error(self):
mock_client = MockClient.get_mock_client(
standalones=[],
members=["a:1", "b:2", "c:3"],
mongoses=[],
host="b:2", # Pass a secondary.
replicaSet="rs",
heartbeatFrequencyMS=500,
)
self.addCleanup(mock_client.close)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(ClientBulkWriteException) as context:
mock_client.mock_client_bulk_write(models=models)
self.assertIsInstance(context.exception.error, TypeError)
self.assertFalse(hasattr(context.exception.error, "details"))