Skip to content

Commit c333423

Browse files
committed
Change FieldExpr to TypeVar to work around Mappings being invariant in the key type
See: - python/typing#445 - python/typing#273
1 parent dde2ed0 commit c333423

File tree

6 files changed

+27
-31
lines changed

6 files changed

+27
-31
lines changed

beanie/odm/documents.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
wrap_with_actions,
4747
)
4848
from beanie.odm.bulk import BulkWriter, Operation
49-
from beanie.odm.fields import FieldExpr, IndexModel, PydanticObjectId
49+
from beanie.odm.fields import IndexModel, PydanticObjectId
5050
from beanie.odm.interfaces.find import BaseSettings, FindInterface
5151
from beanie.odm.interfaces.update import UpdateMethods
5252
from beanie.odm.links import Link, LinkedModelMixin, LinkInfo, LinkTypes
@@ -417,8 +417,7 @@ async def replace(
417417
)
418418

419419
use_revision_id = self._settings.use_revision
420-
find_query: Dict[FieldExpr, Any] = {"_id": self.id}
421-
420+
find_query = {"_id": self.id}
422421
if use_revision_id and not ignore_revision:
423422
find_query["revision_id"] = self._previous_revision_id
424423
try:
@@ -548,7 +547,7 @@ async def replace_many(
548547
:return: None
549548
"""
550549
ids_list = [document.id for document in documents]
551-
if await cls.find(In(cls.id, ids_list)).count() != len(ids_list): # type: ignore[arg-type]
550+
if await cls.find(In("_id", ids_list)).count() != len(ids_list):
552551
raise ReplaceError(
553552
"Some of the documents are not exist in the collection"
554553
)
@@ -582,7 +581,7 @@ async def update(
582581
arguments = list(args)
583582
use_revision_id = self._settings.use_revision
584583

585-
find_query: Dict[FieldExpr, Any] = {
584+
find_query = {
586585
"_id": self.id if self.id is not None else PydanticObjectId()
587586
}
588587
if use_revision_id and not ignore_revision:

beanie/odm/fields.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from functools import cached_property
3-
from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
3+
from typing import Any, Dict, Iterator, List, Mapping, Tuple, TypeVar, Union
44

55
import bson
66
import pymongo
@@ -97,7 +97,7 @@ def __deepcopy__(self, memo: dict) -> Self:
9797
return self
9898

9999

100-
FieldExpr = Union[ExpressionField, str]
100+
FieldExpr = TypeVar("FieldExpr", bound=Union[ExpressionField, str])
101101

102102

103103
def convert_field_exprs_to_str(expression: Any) -> Any:

beanie/odm/interfaces/update.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import abstractmethod
2-
from typing import Any, Mapping, Optional, cast
2+
from typing import Any, Mapping, Optional
33

44
from pymongo.client_session import ClientSession
55

@@ -51,7 +51,7 @@ class Sample(Document):
5151
:return: self
5252
"""
5353
return self.update(
54-
cast(Mapping[FieldExpr, Any], Set(expression)),
54+
Set(expression),
5555
session=session,
5656
bulk_writer=bulk_writer,
5757
**pymongo_kwargs,
@@ -75,7 +75,7 @@ def current_date(
7575
:return: self
7676
"""
7777
return self.update(
78-
cast(Mapping[FieldExpr, Any], CurrentDate(expression)),
78+
CurrentDate(expression),
7979
session=session,
8080
bulk_writer=bulk_writer,
8181
**pymongo_kwargs,
@@ -110,7 +110,7 @@ class Sample(Document):
110110
:return: self
111111
"""
112112
return self.update(
113-
cast(Mapping[FieldExpr, Any], Inc(expression)),
113+
Inc(expression),
114114
session=session,
115115
bulk_writer=bulk_writer,
116116
**pymongo_kwargs,

beanie/odm/links.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,17 @@ def eval_type(cls, t: Any) -> Type["Document"]:
272272

273273
async def fetch_link(self, field: FieldExpr) -> None:
274274
if isinstance(field, ExpressionField):
275-
field = str(field)
276-
ref_obj = getattr(self, field, None)
275+
attr = str(field)
276+
else:
277+
assert isinstance(field, str)
278+
attr = field
279+
ref_obj = getattr(self, attr, None)
277280
if isinstance(ref_obj, Link):
278281
value = await ref_obj.fetch(fetch_links=True)
279-
setattr(self, field, value)
282+
setattr(self, attr, value)
280283
elif isinstance(ref_obj, list) and ref_obj:
281284
values = await Link.fetch_list(ref_obj, fetch_links=True)
282-
setattr(self, field, values)
285+
setattr(self, attr, values)
283286

284287
async def fetch_all_links(self) -> None:
285288
await asyncio.gather(

beanie/odm/queries/find/many.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import beanie
2020
from beanie.odm.bulk import BulkWriter
2121
from beanie.odm.fields import (
22-
ExpressionField,
2322
FieldExpr,
2423
SortDirection,
2524
convert_field_exprs_to_str,
@@ -145,20 +144,18 @@ def sort(
145144
pass
146145
elif isinstance(arg, list):
147146
self.sort(*arg)
148-
elif isinstance(arg, tuple):
149-
self._add_sort(*arg)
150147
else:
151-
self._add_sort(arg)
148+
if isinstance(arg, tuple):
149+
key, direction = arg
150+
else:
151+
key = arg
152+
direction = None
153+
self._add_sort(convert_field_exprs_to_str(key), direction)
152154
return self
153155

154-
def _add_sort(
155-
self, key: FieldExpr, direction: Optional[SortDirection] = None
156-
):
157-
if isinstance(key, ExpressionField):
158-
key = str(key)
159-
elif not isinstance(key, str):
156+
def _add_sort(self, key: str, direction: Optional[SortDirection]) -> None:
157+
if not isinstance(key, str):
160158
raise TypeError(f"Sort key must be a string, not {type(key)}")
161-
162159
if direction is None:
163160
if key.startswith("-"):
164161
direction = SortDirection.DESCENDING

beanie/odm/queries/find/one.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
Any,
44
Generator,
55
Generic,
6-
List,
76
Mapping,
87
Optional,
98
Type,
@@ -202,9 +201,8 @@ async def count(self) -> int:
202201
:return: int
203202
"""
204203
if self.fetch_links:
205-
args = cast(List[Mapping[FieldExpr, Any]], self.find_expressions)
206204
return await self.document_model.find_many(
207-
*args,
205+
*self.find_expressions,
208206
session=self.session,
209207
fetch_links=self.fetch_links,
210208
**self.pymongo_kwargs,
@@ -226,9 +224,8 @@ async def _find(self, use_cache: bool, parse: bool) -> Optional[ModelT]:
226224
doc = await self._find(use_cache=False, parse=False)
227225
cache.set(cache_key, doc)
228226
elif self.fetch_links:
229-
args = cast(List[Mapping[FieldExpr, Any]], self.find_expressions)
230227
doc = await self.document_model.find_many(
231-
*args,
228+
*self.find_expressions,
232229
session=self.session,
233230
fetch_links=self.fetch_links,
234231
projection_model=self.projection_model,

0 commit comments

Comments
 (0)