From 42ef61d4c66090a98a6f3a3c1fd98482ac582eb0 Mon Sep 17 00:00:00 2001 From: slorello89 Date: Thu, 4 Apr 2024 15:01:51 -0400 Subject: [PATCH 01/11] refactoring model to use pydantic 2.0 types and validators --- aredis_om/_compat.py | 94 +++++++++++++- aredis_om/model/model.py | 213 ++++++++++++++++++++++++++----- pyproject.toml | 4 +- tests/_compat.py | 9 +- tests/test_hash_model.py | 8 +- tests/test_json_model.py | 27 ++-- tests/test_oss_redis_features.py | 4 +- 7 files changed, 298 insertions(+), 61 deletions(-) diff --git a/aredis_om/_compat.py b/aredis_om/_compat.py index 0246e4f8..0fe4f486 100644 --- a/aredis_om/_compat.py +++ b/aredis_om/_compat.py @@ -1,19 +1,99 @@ from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Annotated, Literal, get_args, get_origin +from dataclasses import dataclass, is_dataclass +from typing import ( + Any, + Callable, + Deque, + Dict, + FrozenSet, + List, + Mapping, + Sequence, + Set, + Tuple, + Type, + Union, +) + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if PYDANTIC_V2: + def use_pydantic_2_plus(): + return True + from pydantic import BaseModel, validator + from pydantic._internal._model_construction import ModelMetaclass + from pydantic.fields import FieldInfo + from pydantic_core import PydanticUndefined as Undefined, PydanticUndefinedType as UndefinedType + from pydantic.deprecated.json import ENCODERS_BY_TYPE + from pydantic import TypeAdapter + from pydantic import ValidationError as ValidationError + + + from pydantic.v1.main import validate_model + + from pydantic.v1.typing import NoArgAnyCallable + from pydantic._internal._repr import Representation + + @dataclass + class ModelField: + field_info: FieldInfo + name: str + mode: Literal["validation", "serialization"] = "validation" + + @property + def alias(self) -> str: + a = self.field_info.alias + return a if a is not None else self.name + + @property + def required(self) -> bool: + return self.field_info.is_required() + + @property + def default(self) -> Any: + return self.get_default() + + @property + def type_(self) -> Any: + return self.field_info.annotation + + def __post_init__(self) -> None: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info] + ) + + def get_default(self) -> Any: + if self.field_info.is_required(): + return Undefined + return self.field_info.get_default(call_default_factory=True) + + def validate( + self, + value: Any, + values: Dict[str, Any] = {}, # noqa: B006 + *, + loc: Tuple[Union[int, str], ...] = (), + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + return ( + self._type_adapter.validate_python(value, from_attributes=True), + None, + ) + + def __hash__(self) -> int: + # Each ModelField is unique for our purposes, to allow making a dict from + # ModelField to its JSON Schema. + return id(self) + + +else: from pydantic.v1 import BaseModel, validator from pydantic.v1.fields import FieldInfo, ModelField, Undefined, UndefinedType from pydantic.v1.json import ENCODERS_BY_TYPE from pydantic.v1.main import ModelMetaclass, validate_model from pydantic.v1.typing import NoArgAnyCallable from pydantic.v1.utils import Representation -else: - from pydantic import BaseModel, validator - from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType - from pydantic.json import ENCODERS_BY_TYPE - from pydantic.main import ModelMetaclass, validate_model - from pydantic.typing import NoArgAnyCallable - from pydantic.utils import Representation + def use_pydantic_2_plus(): + return False diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index a90b3971..24d59b04 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -75,6 +75,14 @@ ERRORS_URL = "https://github.com/redis/redis-om-python/blob/main/docs/errors.md" +def get_outer_type(field): + if hasattr(field, 'outer_type_'): + return field.outer_type_ + elif isinstance(field.annotation, type) or is_supported_container_type(field.annotation): + return field.annotation + else: + return field.annotation.__args__[0] + class RedisModelError(Exception): """Raised when a problem exists in the definition of a RedisModel.""" @@ -106,7 +114,7 @@ def __str__(self): return str(self.name) -ExpressionOrModelField = Union["Expression", "NegatedExpression", ModelField] +ExpressionOrModelField = Union["Expression", "NegatedExpression", ModelField, PydanticFieldInfo] def embedded(cls): @@ -130,6 +138,9 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any if "__" in field_name: obj = model for sub_field in field_name.split("__"): + if not isinstance(obj, ModelMeta) and hasattr(obj, 'field'): + obj = getattr(obj, 'field').annotation + if not hasattr(obj, sub_field): raise QuerySyntaxError( f"The update path {field_name} contains a field that does not " @@ -331,8 +342,11 @@ def __rshift__(self, other: Any) -> Expression: ) def __getattr__(self, item): - if is_supported_container_type(self.field.outer_type_): - embedded_cls = get_args(self.field.outer_type_) + if item.startswith("__"): + raise AttributeError("cannot invoke __getattr__ with reserved field") + outer_type = outer_type_or_annotation(self.field) + if is_supported_container_type(outer_type): + embedded_cls = get_args(outer_type) if not embedded_cls: raise QuerySyntaxError( "In order to query on a list field, you must define " @@ -342,9 +356,9 @@ def __getattr__(self, item): embedded_cls = embedded_cls[0] attr = getattr(embedded_cls, item) else: - attr = getattr(self.field.outer_type_, item) + attr = getattr(outer_type, item) if isinstance(attr, self.__class__): - new_parent = (self.field.name, self.field.outer_type_) + new_parent = (self.field.alias, outer_type) if new_parent not in attr.parents: attr.parents.append(new_parent) new_parents = list(set(self.parents) - set(attr.parents)) @@ -480,7 +494,12 @@ def validate_sort_fields(self, sort_fields: List[str]): f"does not exist on the model {self.model}" ) field_proxy = getattr(self.model, field_name) - if not getattr(field_proxy.field.field_info, "sortable", False): + if isinstance(field_proxy.field, FieldInfo) or isinstance(field_proxy.field, PydanticFieldInfo): + field_info = field_proxy.field + else: + field_info = field_proxy.field.field_info + + if not getattr(field_info, "sortable", False): raise QueryNotSupportedError( f"You tried sort by {field_name}, but {self.model} does " f"not define that field as sortable. Docs: {ERRORS_URL}#E2" @@ -489,10 +508,14 @@ def validate_sort_fields(self, sort_fields: List[str]): @staticmethod def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes: - if getattr(field.field_info, "primary_key", None) is True: + if not hasattr(field, 'field_info'): + field_info = field + else: + field_info = field.field_info + if getattr(field_info, "primary_key", None) is True: return RediSearchFieldTypes.TAG elif op is Operators.LIKE: - fts = getattr(field.field_info, "full_text_search", None) + fts = getattr(field_info, "full_text_search", None) if fts is not True: # Could be PydanticUndefined raise QuerySyntaxError( f"You tried to do a full-text search on the field '{field.name}', " @@ -501,7 +524,7 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes ) return RediSearchFieldTypes.TEXT - field_type = field.outer_type_ + field_type = outer_type_or_annotation(field) # TODO: GEO fields container_type = get_origin(field_type) @@ -729,6 +752,15 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: f"You tried to query by a field ({field_name}) " f"that isn't indexed. Docs: {ERRORS_URL}#E6" ) + elif isinstance(expression.left, FieldInfo): + field_type = cls.resolve_field_type(expression.left, expression.op) + field_name = expression.left.alias + field_info = expression.left + if not field_info or not getattr(field_info, "index", None): + raise QueryNotSupportedError( + f"You tried to query by a field ({field_name}) " + f"that isn't indexed. Docs: {ERRORS_URL}#E6" + ) else: raise QueryNotSupportedError( "A query expression should start with either a field " @@ -1156,7 +1188,7 @@ def Field( vector_options=vector_options, **current_schema_extra, ) - field_info._validate() + # field_info._validate() return field_info @@ -1232,6 +1264,17 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) for field_name, field in new_class.__fields__.items(): + if not isinstance(field, FieldInfo): + for base_candidate in bases: + if hasattr(base_candidate, field_name): + inner_field = getattr(base_candidate, field_name) + if hasattr(inner_field, 'field') and isinstance(getattr(inner_field, 'field'), FieldInfo): + field.metadata.append(getattr(inner_field, 'field')) + field = getattr(inner_field, 'field') + + + if not field.alias: + field.alias = field_name setattr(new_class, field_name, ExpressionProxy(field, [])) annotation = new_class.get_annotations().get(field_name) if annotation: @@ -1241,12 +1284,20 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 else: new_class.__annotations__[field_name] = ExpressionProxy # Check if this is our FieldInfo version with extended ORM metadata. - if isinstance(field.field_info, FieldInfo): - if field.field_info.primary_key: + # if isinstance(field.field_info, FieldInfo): + field_info = None + if hasattr(field, 'field_info') and isinstance(field.field_info, FieldInfo): + field_info = field.field_info + elif field_name in attrs and isinstance(attrs.__getitem__(field_name), FieldInfo): + field_info = attrs.__getitem__(field_name) + field.field_info = field_info + + if field_info is not None: + if field_info.primary_key: new_class._meta.primary_key = PrimaryKey( name=field_name, field=field ) - if field.field_info.vector_options: + if field_info.vector_options: score_attr = f"_{field_name}_score" setattr(new_class, score_attr, None) new_class.__annotations__[score_attr] = Union[float, None] @@ -1290,8 +1341,21 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 return new_class +def outer_type_or_annotation(field): + if hasattr(field, 'outer_type_'): + return field.outer_type_ + elif not hasattr(field.annotation, '__args__'): + if not isinstance(field.annotation, type): + raise AttributeError(f"could not extract outer type from field {field}") + return field.annotation + else: + return field.annotation.__args__[0] + + class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): +# class RedisModel(BaseModel, abc.ABC): pk: Optional[str] = Field(default=None, primary_key=True) + # pk: Optional[str] = Field(default=None, primary_key=True) Meta = DefaultMeta @@ -1310,7 +1374,10 @@ def __lt__(self, other): def key(self): """Return the Redis key for this model.""" - pk = getattr(self, self._meta.primary_key.field.name) + if hasattr(self._meta.primary_key.field, 'name'): + pk = getattr(self, self._meta.primary_key.field.name) + else: + pk = getattr(self, self._meta.primary_key.name) return self.make_primary_key(pk) @classmethod @@ -1349,7 +1416,7 @@ async def expire( @validator("pk", always=True, allow_reuse=True) def validate_pk(cls, v): - if not v: + if not v or isinstance(v, ExpressionProxy): v = cls._meta.primary_key_creator_cls().create_pk() return v @@ -1358,7 +1425,15 @@ def validate_primary_key(cls): """Check for a primary key. We need one (and only one).""" primary_keys = 0 for name, field in cls.__fields__.items(): - if getattr(field.field_info, "primary_key", None): + if not hasattr(field, 'field_info'): + if not isinstance(field, FieldInfo) and hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + field_info = field.metadata[0] + else: + field_info = field + else: + field_info = field.field_info + + if getattr(field_info, "primary_key", None): primary_keys += 1 if primary_keys == 0: raise RedisModelError("You must define a primary key for the model") @@ -1490,17 +1565,37 @@ def redisearch_schema(cls): def check(self): """Run all validations.""" - *_, validation_error = validate_model(self.__class__, self.__dict__) - if validation_error: - raise validation_error + from pydantic.version import VERSION as PYDANTIC_VERSION + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + if not PYDANTIC_V2: + *_, validation_error = validate_model(self.__class__, self.__dict__) + if validation_error: + raise validation_error class HashModel(RedisModel, abc.ABC): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) + if hasattr(cls, '__annotations__'): + for name, field_type in cls.__annotations__.items(): + origin = get_origin(field_type) + for typ in (Set, Mapping, List): + if isinstance(origin, type) and issubclass(origin, typ): + raise RedisModelError( + f"HashModels cannot index set, list," + f" or mapping fields. Field: {name}" + ) + if isinstance(field_type, type) and issubclass(field_type, RedisModel): + raise RedisModelError(f"HashModels cannot index embedded model fields. Field: {name}") + elif isinstance(field_type, type) and dataclasses.is_dataclass(field_type): + raise RedisModelError( + f"HashModels cannot index dataclass fields. Field: {name}" + ) + for name, field in cls.__fields__.items(): - origin = get_origin(field.outer_type_) + outer_type = outer_type_or_annotation(field) + origin = get_origin(outer_type) if origin: for typ in (Set, Mapping, List): if issubclass(origin, typ): @@ -1509,11 +1604,11 @@ def __init_subclass__(cls, **kwargs): f" or mapping fields. Field: {name}" ) - if issubclass(field.outer_type_, RedisModel): + if issubclass(outer_type, RedisModel): raise RedisModelError( f"HashModels cannot index embedded model fields. Field: {name}" ) - elif dataclasses.is_dataclass(field.outer_type_): + elif dataclasses.is_dataclass(outer_type): raise RedisModelError( f"HashModels cannot index dataclass fields. Field: {name}" ) @@ -1524,6 +1619,10 @@ async def save( self.check() db = self._get_db(pipeline) + # if hasattr(self,'model_fields_set'): + # dict = {k: v for k, v in self.dict().items() if k in self.model_fields_set} + # else: + # dict = self.dict() document = jsonable_encoder(self.dict()) # TODO: Wrap any Redis response errors in a custom exception? await db.hset(self.key(), mapping=document) @@ -1594,21 +1693,29 @@ def schema_for_fields(cls): for name, field in cls.__fields__.items(): # TODO: Merge this code with schema_for_type()? - _type = field.outer_type_ + _type = outer_type_or_annotation(field) is_subscripted_type = get_origin(_type) - if getattr(field.field_info, "primary_key", None): + if not isinstance(field, FieldInfo) and hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + field = field.metadata[0] + + if not hasattr(field, 'field_info'): + field_info = field + else: + field_info = field.field_info + + if getattr(field_info, "primary_key", None): if issubclass(_type, str): redisearch_field = ( f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" ) else: redisearch_field = cls.schema_for_type( - name, _type, field.field_info + name, _type, field_info ) schema_parts.append(redisearch_field) - elif getattr(field.field_info, "index", None) is True: - schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) + elif getattr(field_info, "index", None) is True: + schema_parts.append(cls.schema_for_type(name, _type, field_info)) elif is_subscripted_type: # Ignore subscripted types (usually containers!) that we don't # support, for the purposes of indexing. @@ -1622,10 +1729,10 @@ def schema_for_fields(cls): continue embedded_cls = embedded_cls[0] schema_parts.append( - cls.schema_for_type(name, embedded_cls, field.field_info) + cls.schema_for_type(name, embedded_cls, field_info) ) elif issubclass(_type, RedisModel): - schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) + schema_parts.append(cls.schema_for_type(name, _type, field_info)) return schema_parts @classmethod @@ -1760,11 +1867,41 @@ def redisearch_schema(cls): def schema_for_fields(cls): schema_parts = [] json_path = "$" - + fields = dict() for name, field in cls.__fields__.items(): - _type = field.outer_type_ + fields[name] = field + for name, field in cls.__dict__.items(): + if isinstance(field, FieldInfo): + if not field.annotation: + field.annotation = cls.__annotations__.get(name) + fields[name] = field + for name, field in cls.__annotations__.items(): + if name in fields: + continue + fields[name] = PydanticFieldInfo.from_annotation(field) + + for name, field in fields.items(): + _type = get_outer_type(field) + if not isinstance(field, FieldInfo) and hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + field = field.metadata[0] + + if hasattr(field, 'field_info'): + field_info = field.field_info + else: + field_info = field + if getattr(field_info, "primary_key", None): + if issubclass(_type, str): + redisearch_field = ( + f"$.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" + ) + else: + redisearch_field = cls.schema_for_type( + name, _type, field_info + ) + schema_parts.append(redisearch_field) + continue schema_parts.append( - cls.schema_for_type(json_path, name, "", _type, field.field_info) + cls.schema_for_type(json_path, name, "", _type, field_info) ) return schema_parts @@ -1843,6 +1980,13 @@ def schema_for_type( name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] for embedded_name, field in typ.__fields__.items(): + if hasattr(field, 'field_info'): + field_info = field.field_info + elif hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + field_info = field.metadata[0] + else: + field_info = field + if parent_is_container_type: # We'll store this value either as a JavaScript array, so # the correct JSONPath expression is to refer directly to @@ -1859,8 +2003,9 @@ def schema_for_type( path, embedded_name, name_prefix, - field.outer_type_, - field.field_info, + # field.annotation, + get_outer_type(field), + field_info, parent_type=typ, ) ) diff --git a/pyproject.toml b/pyproject.toml index c92dae15..5084e021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "0.2.2" +version = "0.3.0" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -37,7 +37,7 @@ include=[ [tool.poetry.dependencies] python = ">=3.8,<4.0" redis = ">=3.5.3,<6.0.0" -pydantic = ">=1.10.2,<2.5.0" +pydantic = ">=1.10.2,<3.0.0" click = "^8.0.1" types-redis = ">=3.5.9,<5.0.0" python-ulid = "^1.0.3" diff --git a/tests/_compat.py b/tests/_compat.py index c21b47d2..f0a0b28a 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -1,7 +1,10 @@ -from aredis_om._compat import PYDANTIC_V2 +from aredis_om._compat import use_pydantic_2_plus, PYDANTIC_V2 +if not use_pydantic_2_plus() and PYDANTIC_V2: + from pydantic.v1 import ValidationError, EmailStr +elif PYDANTIC_V2: + from pydantic import ValidationError + from pydantic import EmailStr -if PYDANTIC_V2: - from pydantic.v1 import EmailStr, ValidationError else: from pydantic import EmailStr, ValidationError diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 38ca18e2..aa1e496b 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -388,7 +388,10 @@ def test_validates_required_fields(m): # Raises ValidationError: last_name is required # TODO: Test the error value with pytest.raises(ValidationError): - m.Member(id=0, first_name="Andrew", zipcode="97086", join_date=today) + try: + m.Member(id=0, first_name="Andrew", zipcode="97086", join_date=today) + except Exception as e: + raise e; def test_validates_field(m): @@ -581,6 +584,7 @@ class Address(m.BaseHashModel): with pytest.raises(RedisModelError): class InvalidMember(m.BaseHashModel): + name: str = Field(index=True) address: Address @@ -728,7 +732,7 @@ class Address(m.BaseHashModel): # We need to build the key prefix because it will differ based on whether # these tests were copied into the tests_sync folder and unasynce'd. key_prefix = Address.make_key(Address._meta.primary_key_pattern.format(pk="")) - + schema = Address.redisearch_schema() assert ( Address.redisearch_schema() == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 8fec6c0a..8555b841 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -50,12 +50,12 @@ class Note(EmbeddedJsonModel): class Address(EmbeddedJsonModel): address_line_1: str - address_line_2: Optional[str] + address_line_2: Optional[str] = None city: str = Field(index=True) state: str country: str postal_code: str = Field(index=True) - note: Optional[Note] + note: Optional[Note] = None class Item(EmbeddedJsonModel): price: decimal.Decimal @@ -68,16 +68,16 @@ class Order(EmbeddedJsonModel): class Member(BaseJsonModel): first_name: str = Field(index=True) last_name: str = Field(index=True) - email: str = Field(index=True) + email: Optional[str] = Field(index=True, default=None) join_date: datetime.date - age: int = Field(index=True) + age: Optional[int] = Field(index=True, default=None) bio: Optional[str] = Field(index=True, full_text_search=True, default="") # Creates an embedded model. address: Address # Creates an embedded list of models. - orders: Optional[List[Order]] + orders: Optional[List[Order]] = None await Migrator().run() @@ -88,13 +88,16 @@ class Member(BaseJsonModel): @pytest.fixture() def address(m): - yield m.Address( - address_line_1="1 Main St.", - city="Portland", - state="OR", - country="USA", - postal_code=11111, - ) + try: + yield m.Address( + address_line_1="1 Main St.", + city="Portland", + state="OR", + country="USA", + postal_code='11111', + ) + except Exception as e: + raise e @pytest_asyncio.fixture() diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 4d5b0913..47ebe47f 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -166,7 +166,9 @@ async def test_saves_many(m): result = await m.Member.add(members) assert result == [member1, member2] - assert await m.Member.get(pk=member1.pk) == member1 + m1_rematerialized = await m.Member.get(pk=member1.pk) + + assert m1_rematerialized == member1 assert await m.Member.get(pk=member2.pk) == member2 From 69cb9ef0089796137f1c32f24c9d40f7deb8993a Mon Sep 17 00:00:00 2001 From: slorello89 Date: Thu, 4 Apr 2024 15:45:06 -0400 Subject: [PATCH 02/11] couple more tests --- aredis_om/_compat.py | 12 ++++++------ aredis_om/model/model.py | 2 ++ tests/_compat.py | 2 +- tests/test_json_model.py | 31 +++++++++++++++++++++++++++++-- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/aredis_om/_compat.py b/aredis_om/_compat.py index 0fe4f486..de10ba1b 100644 --- a/aredis_om/_compat.py +++ b/aredis_om/_compat.py @@ -89,11 +89,11 @@ def __hash__(self) -> int: else: - from pydantic.v1 import BaseModel, validator - from pydantic.v1.fields import FieldInfo, ModelField, Undefined, UndefinedType - from pydantic.v1.json import ENCODERS_BY_TYPE - from pydantic.v1.main import ModelMetaclass, validate_model - from pydantic.v1.typing import NoArgAnyCallable - from pydantic.v1.utils import Representation + from pydantic import BaseModel, validator + from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType + from pydantic.json import ENCODERS_BY_TYPE + from pydantic.main import ModelMetaclass, validate_model + from pydantic.typing import NoArgAnyCallable + from pydantic.utils import Representation def use_pydantic_2_plus(): return False diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 24d59b04..239dc299 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -80,6 +80,8 @@ def get_outer_type(field): return field.outer_type_ elif isinstance(field.annotation, type) or is_supported_container_type(field.annotation): return field.annotation + elif not isinstance(field.annotation.__args__[0], type): + return field.annotation.__args__[0].__origin__ else: return field.annotation.__args__[0] diff --git a/tests/_compat.py b/tests/_compat.py index f0a0b28a..9d360116 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -4,7 +4,7 @@ from pydantic.v1 import ValidationError, EmailStr elif PYDANTIC_V2: from pydantic import ValidationError - from pydantic import EmailStr + from pydantic import EmailStr, PositiveInt else: from pydantic import EmailStr, ValidationError diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 8555b841..1a828739 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -7,6 +7,7 @@ from collections import namedtuple from typing import Dict, List, Optional, Set from unittest import mock +from tests._compat import EmailStr, PositiveInt import pytest import pytest_asyncio @@ -68,9 +69,9 @@ class Order(EmbeddedJsonModel): class Member(BaseJsonModel): first_name: str = Field(index=True) last_name: str = Field(index=True) - email: Optional[str] = Field(index=True, default=None) + email: Optional[EmailStr] = Field(index=True, default=None) join_date: datetime.date - age: Optional[int] = Field(index=True, default=None) + age: Optional[PositiveInt] = Field(index=True, default=None) bio: Optional[str] = Field(index=True, full_text_search=True, default="") # Creates an embedded model. @@ -136,6 +137,32 @@ async def members(address, m): yield member1, member2, member3 +@py_test_mark_asyncio +async def test_validate_bad_email(address, m): + # Raises ValidationError as email is malformed + with pytest.raises(ValidationError): + m.Member( + first_name="Andrew", + last_name="Brookins", + zipcode="97086", + join_date=today, + email = 'foobarbaz' + ) + +@py_test_mark_asyncio +async def test_validate_bad_age(address, m): + # Raises ValidationError as email is malformed + with pytest.raises(ValidationError): + m.Member( + first_name="Andrew", + last_name="Brookins", + zipcode="97086", + join_date=today, + email='foo@bar.com', + address=address, + age=-5 + ) + @py_test_mark_asyncio async def test_validates_required_fields(address, m): # Raises ValidationError address is required From 69c78918591f22b1183a6e369de19e52d590899f Mon Sep 17 00:00:00 2001 From: slorello89 Date: Thu, 4 Apr 2024 16:52:23 -0400 Subject: [PATCH 03/11] lintining --- aredis_om/_compat.py | 24 ++++---- aredis_om/model/model.py | 118 +++++++++++++++++++++++++-------------- tests/_compat.py | 8 +-- tests/test_hash_model.py | 2 +- tests/test_json_model.py | 13 +++-- 5 files changed, 99 insertions(+), 66 deletions(-) diff --git a/aredis_om/_compat.py b/aredis_om/_compat.py index de10ba1b..07dc2824 100644 --- a/aredis_om/_compat.py +++ b/aredis_om/_compat.py @@ -1,5 +1,3 @@ -from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Annotated, Literal, get_args, get_origin from dataclasses import dataclass, is_dataclass from typing import ( Any, @@ -16,26 +14,28 @@ Union, ) +from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Annotated, Literal, get_args, get_origin PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if PYDANTIC_V2: + def use_pydantic_2_plus(): return True - from pydantic import BaseModel, validator + + from pydantic import BaseModel, TypeAdapter + from pydantic import ValidationError as ValidationError + from pydantic import validator from pydantic._internal._model_construction import ModelMetaclass - from pydantic.fields import FieldInfo - from pydantic_core import PydanticUndefined as Undefined, PydanticUndefinedType as UndefinedType + from pydantic._internal._repr import Representation from pydantic.deprecated.json import ENCODERS_BY_TYPE - from pydantic import TypeAdapter - from pydantic import ValidationError as ValidationError - - + from pydantic.fields import FieldInfo from pydantic.v1.main import validate_model - from pydantic.v1.typing import NoArgAnyCallable - from pydantic._internal._repr import Representation + from pydantic_core import PydanticUndefined as Undefined + from pydantic_core import PydanticUndefinedType as UndefinedType @dataclass class ModelField: @@ -87,7 +87,6 @@ def __hash__(self) -> int: # ModelField to its JSON Schema. return id(self) - else: from pydantic import BaseModel, validator from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType @@ -95,5 +94,6 @@ def __hash__(self) -> int: from pydantic.main import ModelMetaclass, validate_model from pydantic.typing import NoArgAnyCallable from pydantic.utils import Representation + def use_pydantic_2_plus(): return False diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 239dc299..754e76f8 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -76,15 +76,18 @@ def get_outer_type(field): - if hasattr(field, 'outer_type_'): + if hasattr(field, "outer_type_"): return field.outer_type_ - elif isinstance(field.annotation, type) or is_supported_container_type(field.annotation): + elif isinstance(field.annotation, type) or is_supported_container_type( + field.annotation + ): return field.annotation - elif not isinstance(field.annotation.__args__[0], type): - return field.annotation.__args__[0].__origin__ + # elif not isinstance(field.annotation.__args__[0], type): + # return field.annotation.__args__[0].__origin__ else: return field.annotation.__args__[0] + class RedisModelError(Exception): """Raised when a problem exists in the definition of a RedisModel.""" @@ -116,7 +119,9 @@ def __str__(self): return str(self.name) -ExpressionOrModelField = Union["Expression", "NegatedExpression", ModelField, PydanticFieldInfo] +ExpressionOrModelField = Union[ + "Expression", "NegatedExpression", ModelField, PydanticFieldInfo +] def embedded(cls): @@ -140,8 +145,8 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any if "__" in field_name: obj = model for sub_field in field_name.split("__"): - if not isinstance(obj, ModelMeta) and hasattr(obj, 'field'): - obj = getattr(obj, 'field').annotation + if not isinstance(obj, ModelMeta) and hasattr(obj, "field"): + obj = getattr(obj, "field").annotation if not hasattr(obj, sub_field): raise QuerySyntaxError( @@ -496,7 +501,9 @@ def validate_sort_fields(self, sort_fields: List[str]): f"does not exist on the model {self.model}" ) field_proxy = getattr(self.model, field_name) - if isinstance(field_proxy.field, FieldInfo) or isinstance(field_proxy.field, PydanticFieldInfo): + if isinstance(field_proxy.field, FieldInfo) or isinstance( + field_proxy.field, PydanticFieldInfo + ): field_info = field_proxy.field else: field_info = field_proxy.field.field_info @@ -510,7 +517,7 @@ def validate_sort_fields(self, sort_fields: List[str]): @staticmethod def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes: - if not hasattr(field, 'field_info'): + if not hasattr(field, "field_info"): field_info = field else: field_info = field.field_info @@ -528,6 +535,9 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes field_type = outer_type_or_annotation(field) + if not isinstance(field_type, type): + field_type = field_type.__origin__ + # TODO: GEO fields container_type = get_origin(field_type) @@ -1270,10 +1280,11 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 for base_candidate in bases: if hasattr(base_candidate, field_name): inner_field = getattr(base_candidate, field_name) - if hasattr(inner_field, 'field') and isinstance(getattr(inner_field, 'field'), FieldInfo): - field.metadata.append(getattr(inner_field, 'field')) - field = getattr(inner_field, 'field') - + if hasattr(inner_field, "field") and isinstance( + getattr(inner_field, "field"), FieldInfo + ): + field.metadata.append(getattr(inner_field, "field")) + field = getattr(inner_field, "field") if not field.alias: field.alias = field_name @@ -1288,9 +1299,11 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 # Check if this is our FieldInfo version with extended ORM metadata. # if isinstance(field.field_info, FieldInfo): field_info = None - if hasattr(field, 'field_info') and isinstance(field.field_info, FieldInfo): + if hasattr(field, "field_info") and isinstance(field.field_info, FieldInfo): field_info = field.field_info - elif field_name in attrs and isinstance(attrs.__getitem__(field_name), FieldInfo): + elif field_name in attrs and isinstance( + attrs.__getitem__(field_name), FieldInfo + ): field_info = attrs.__getitem__(field_name) field.field_info = field_info @@ -1344,9 +1357,9 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 def outer_type_or_annotation(field): - if hasattr(field, 'outer_type_'): + if hasattr(field, "outer_type_"): return field.outer_type_ - elif not hasattr(field.annotation, '__args__'): + elif not hasattr(field.annotation, "__args__"): if not isinstance(field.annotation, type): raise AttributeError(f"could not extract outer type from field {field}") return field.annotation @@ -1355,7 +1368,7 @@ def outer_type_or_annotation(field): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): -# class RedisModel(BaseModel, abc.ABC): + # class RedisModel(BaseModel, abc.ABC): pk: Optional[str] = Field(default=None, primary_key=True) # pk: Optional[str] = Field(default=None, primary_key=True) @@ -1376,7 +1389,7 @@ def __lt__(self, other): def key(self): """Return the Redis key for this model.""" - if hasattr(self._meta.primary_key.field, 'name'): + if hasattr(self._meta.primary_key.field, "name"): pk = getattr(self, self._meta.primary_key.field.name) else: pk = getattr(self, self._meta.primary_key.name) @@ -1427,8 +1440,13 @@ def validate_primary_key(cls): """Check for a primary key. We need one (and only one).""" primary_keys = 0 for name, field in cls.__fields__.items(): - if not hasattr(field, 'field_info'): - if not isinstance(field, FieldInfo) and hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + if not hasattr(field, "field_info"): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): field_info = field.metadata[0] else: field_info = field @@ -1568,6 +1586,7 @@ def redisearch_schema(cls): def check(self): """Run all validations.""" from pydantic.version import VERSION as PYDANTIC_VERSION + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if not PYDANTIC_V2: *_, validation_error = validate_model(self.__class__, self.__dict__) @@ -1579,7 +1598,7 @@ class HashModel(RedisModel, abc.ABC): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if hasattr(cls, '__annotations__'): + if hasattr(cls, "__annotations__"): for name, field_type in cls.__annotations__.items(): origin = get_origin(field_type) for typ in (Set, Mapping, List): @@ -1589,8 +1608,12 @@ def __init_subclass__(cls, **kwargs): f" or mapping fields. Field: {name}" ) if isinstance(field_type, type) and issubclass(field_type, RedisModel): - raise RedisModelError(f"HashModels cannot index embedded model fields. Field: {name}") - elif isinstance(field_type, type) and dataclasses.is_dataclass(field_type): + raise RedisModelError( + f"HashModels cannot index embedded model fields. Field: {name}" + ) + elif isinstance(field_type, type) and dataclasses.is_dataclass( + field_type + ): raise RedisModelError( f"HashModels cannot index dataclass fields. Field: {name}" ) @@ -1698,10 +1721,15 @@ def schema_for_fields(cls): _type = outer_type_or_annotation(field) is_subscripted_type = get_origin(_type) - if not isinstance(field, FieldInfo) and hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): field = field.metadata[0] - if not hasattr(field, 'field_info'): + if not hasattr(field, "field_info"): field_info = field else: field_info = field.field_info @@ -1712,9 +1740,7 @@ def schema_for_fields(cls): f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" ) else: - redisearch_field = cls.schema_for_type( - name, _type, field_info - ) + redisearch_field = cls.schema_for_type(name, _type, field_info) schema_parts.append(redisearch_field) elif getattr(field_info, "index", None) is True: schema_parts.append(cls.schema_for_type(name, _type, field_info)) @@ -1730,9 +1756,7 @@ def schema_for_fields(cls): log.warning("Model %s defined an empty list field: %s", cls, name) continue embedded_cls = embedded_cls[0] - schema_parts.append( - cls.schema_for_type(name, embedded_cls, field_info) - ) + schema_parts.append(cls.schema_for_type(name, embedded_cls, field_info)) elif issubclass(_type, RedisModel): schema_parts.append(cls.schema_for_type(name, _type, field_info)) return schema_parts @@ -1875,7 +1899,7 @@ def schema_for_fields(cls): for name, field in cls.__dict__.items(): if isinstance(field, FieldInfo): if not field.annotation: - field.annotation = cls.__annotations__.get(name) + field.annotation = cls.__annotations__.get(name) fields[name] = field for name, field in cls.__annotations__.items(): if name in fields: @@ -1884,22 +1908,23 @@ def schema_for_fields(cls): for name, field in fields.items(): _type = get_outer_type(field) - if not isinstance(field, FieldInfo) and hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): field = field.metadata[0] - if hasattr(field, 'field_info'): + if hasattr(field, "field_info"): field_info = field.field_info else: field_info = field if getattr(field_info, "primary_key", None): if issubclass(_type, str): - redisearch_field = ( - f"$.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" - ) + redisearch_field = f"$.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" else: - redisearch_field = cls.schema_for_type( - name, _type, field_info - ) + redisearch_field = cls.schema_for_type(name, _type, field_info) schema_parts.append(redisearch_field) continue schema_parts.append( @@ -1982,9 +2007,13 @@ def schema_for_type( name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] for embedded_name, field in typ.__fields__.items(): - if hasattr(field, 'field_info'): + if hasattr(field, "field_info"): field_info = field.field_info - elif hasattr(field, 'metadata') and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo): + elif ( + hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): field_info = field.metadata[0] else: field_info = field @@ -2031,6 +2060,9 @@ def schema_for_type( "See docs: TODO" ) + if not isinstance(typ, type): + typ = field_info.annotation.__args__[0].__origin__ + # TODO: GEO field if is_vector and vector_options: schema = f"{path} AS {index_field_name} {vector_options.schema}" diff --git a/tests/_compat.py b/tests/_compat.py index 9d360116..1cd55bf2 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -1,10 +1,10 @@ -from aredis_om._compat import use_pydantic_2_plus, PYDANTIC_V2 +from aredis_om._compat import PYDANTIC_V2, use_pydantic_2_plus + if not use_pydantic_2_plus() and PYDANTIC_V2: - from pydantic.v1 import ValidationError, EmailStr + from pydantic.v1 import EmailStr, ValidationError elif PYDANTIC_V2: - from pydantic import ValidationError - from pydantic import EmailStr, PositiveInt + from pydantic import EmailStr, PositiveInt, ValidationError else: from pydantic import EmailStr, ValidationError diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index aa1e496b..f0993a94 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -391,7 +391,7 @@ def test_validates_required_fields(m): try: m.Member(id=0, first_name="Andrew", zipcode="97086", join_date=today) except Exception as e: - raise e; + raise e def test_validates_field(m): diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 1a828739..bcd95af7 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -7,7 +7,6 @@ from collections import namedtuple from typing import Dict, List, Optional, Set from unittest import mock -from tests._compat import EmailStr, PositiveInt import pytest import pytest_asyncio @@ -25,7 +24,7 @@ # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. from redis_om import has_redis_json -from tests._compat import ValidationError +from tests._compat import EmailStr, PositiveInt, ValidationError from .conftest import py_test_mark_asyncio @@ -95,7 +94,7 @@ def address(m): city="Portland", state="OR", country="USA", - postal_code='11111', + postal_code="11111", ) except Exception as e: raise e @@ -146,9 +145,10 @@ async def test_validate_bad_email(address, m): last_name="Brookins", zipcode="97086", join_date=today, - email = 'foobarbaz' + email="foobarbaz", ) + @py_test_mark_asyncio async def test_validate_bad_age(address, m): # Raises ValidationError as email is malformed @@ -158,11 +158,12 @@ async def test_validate_bad_age(address, m): last_name="Brookins", zipcode="97086", join_date=today, - email='foo@bar.com', + email="foo@bar.com", address=address, - age=-5 + age=-5, ) + @py_test_mark_asyncio async def test_validates_required_fields(address, m): # Raises ValidationError address is required From cc6b4b77d88df300927225d48ee8127bc72251cb Mon Sep 17 00:00:00 2001 From: slorello89 Date: Thu, 4 Apr 2024 16:55:33 -0400 Subject: [PATCH 04/11] more linting --- tests/test_hash_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index f0993a94..028de57c 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -732,7 +732,6 @@ class Address(m.BaseHashModel): # We need to build the key prefix because it will differ based on whether # these tests were copied into the tests_sync folder and unasynce'd. key_prefix = Address.make_key(Address._meta.primary_key_pattern.format(pk="")) - schema = Address.redisearch_schema() assert ( Address.redisearch_schema() == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC" From dc3107c442d2b624594060d0028ee2945d05219d Mon Sep 17 00:00:00 2001 From: slorello89 Date: Fri, 5 Apr 2024 07:56:23 -0400 Subject: [PATCH 05/11] bumpying mypy version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5084e021..b9ed08ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ more-itertools = ">=8.14,<11.0" setuptools = {version = "^69.2.0", markers = "python_version >= '3.12'"} [tool.poetry.dev-dependencies] -mypy = "^0.982" +mypy = "^1.9.0" pytest = "^8.0.2" ipdb = "^0.13.9" black = "^24.2" From 1571c4a4c2ce9a8661cb0a0a62d3ff3f4c121f3b Mon Sep 17 00:00:00 2001 From: slorello89 Date: Fri, 5 Apr 2024 09:38:18 -0400 Subject: [PATCH 06/11] more linting --- aredis_om/model/encoders.py | 2 +- aredis_om/model/model.py | 50 +++++++++++++++++++++---------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index 2f90e481..f097a35d 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -68,7 +68,7 @@ def jsonable_encoder( if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) - if isinstance(obj, BaseModel): + if isinstance(obj, BaseModel) and hasattr(obj, "__config__"): encoder = getattr(obj.__config__, "json_encoders", {}) if custom_encoder: encoder.update(custom_encoder) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 754e76f8..1b90312b 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -21,8 +21,9 @@ Type, TypeVar, Union, - no_type_check, ) +from typing import get_args as typing_get_args +from typing import no_type_check from more_itertools import ichunked from redis.commands.json.path import Path @@ -82,8 +83,6 @@ def get_outer_type(field): field.annotation ): return field.annotation - # elif not isinstance(field.annotation.__args__[0], type): - # return field.annotation.__args__[0].__origin__ else: return field.annotation.__args__[0] @@ -156,7 +155,7 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any obj = getattr(obj, sub_field) return - if field_name not in model.__fields__: + if field_name not in model.__fields__: # type: ignore raise QuerySyntaxError( f"The field {field_name} does not exist on the model {model.__name__}" ) @@ -495,7 +494,7 @@ def validate_sort_fields(self, sort_fields: List[str]): field_name = sort_field.lstrip("-") if self.knn and field_name == self.knn.score_field: continue - if field_name not in self.model.__fields__: + if field_name not in self.model.__fields__: # type: ignore raise QueryNotSupportedError( f"You tried sort by {field_name}, but that field " f"does not exist on the model {self.model}" @@ -516,7 +515,11 @@ def validate_sort_fields(self, sort_fields: List[str]): return sort_fields @staticmethod - def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes: + def resolve_field_type( + field: Union[ModelField, PydanticFieldInfo], op: Operators + ) -> RediSearchFieldTypes: + field_info: Union[FieldInfo, ModelField, PydanticFieldInfo] + if not hasattr(field, "field_info"): field_info = field else: @@ -527,7 +530,7 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes fts = getattr(field_info, "full_text_search", None) if fts is not True: # Could be PydanticUndefined raise QuerySyntaxError( - f"You tried to do a full-text search on the field '{field.name}', " + f"You tried to do a full-text search on the field '{field.alias}', " f"but the field is not indexed for full-text search. Use the " f"full_text_search=True option. Docs: {ERRORS_URL}#E3" ) @@ -1144,27 +1147,27 @@ def Field( default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None, - alias: str = None, - title: str = None, - description: str = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any ] = None, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any ] = None, - const: bool = None, - gt: float = None, - ge: float = None, - lt: float = None, - le: float = None, - multiple_of: float = None, - min_items: int = None, - max_items: int = None, - min_length: int = None, - max_length: int = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, allow_mutation: bool = True, - regex: str = None, + regex: Optional[str] = None, primary_key: bool = False, sortable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, @@ -2060,8 +2063,11 @@ def schema_for_type( "See docs: TODO" ) + # For more complicated compound validators (e.g. PositiveInt), we might get a _GenericAlias rather than + # a proper type, we can pull the type information from the origin of the first argument. if not isinstance(typ, type): - typ = field_info.annotation.__args__[0].__origin__ + type_args = typing_get_args(field_info.annotation) + typ = type_args[0].__origin__ # TODO: GEO field if is_vector and vector_options: From 5a01b1d0a8e7a97a52eb920d1e703d2377655471 Mon Sep 17 00:00:00 2001 From: slorello89 Date: Fri, 5 Apr 2024 10:11:47 -0400 Subject: [PATCH 07/11] adding tests for #591 --- tests/test_hash_model.py | 25 ++++++++++++++++++++++++- tests/test_json_model.py | 25 ++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 028de57c..b682a753 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -5,7 +5,7 @@ import datetime import decimal from collections import namedtuple -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union from unittest import mock import pytest @@ -807,3 +807,26 @@ async def test_count(members, m): m.Member.first_name == "Kim", m.Member.last_name == "Brookins" ).count() assert actual_count == 1 + + +@py_test_mark_asyncio +async def test_type_with_union(members, m): + class TypeWithUnion(m.BaseHashModel): + field: Union[str, int] + + twu_str = TypeWithUnion(field="hello world") + res = await twu_str.save() + assert res.pk == twu_str.pk + twu_str_rematerialized = await TypeWithUnion.get(twu_str.pk) + assert ( + isinstance(twu_str_rematerialized.field, str) + and twu_str_rematerialized.pk == twu_str.pk + ) + + twu_int = TypeWithUnion(field=42) + await twu_int.save() + twu_int_rematerialized = await TypeWithUnion.get(twu_int.pk) + + # Note - we will not be able to automatically serialize an int back to this union type, + # since as far as we know from Redis this item is a string + assert twu_int_rematerialized.pk == twu_int.pk diff --git a/tests/test_json_model.py b/tests/test_json_model.py index bcd95af7..ebc6d082 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -5,7 +5,7 @@ import datetime import decimal from collections import namedtuple -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union from unittest import mock import pytest @@ -880,3 +880,26 @@ async def test_count(members, m): m.Member.first_name == "Kim", m.Member.last_name == "Brookins" ).count() assert actual_count == 1 + + +@py_test_mark_asyncio +async def test_type_with_union(members, m): + class TypeWithUnion(m.BaseJsonModel): + field: Union[str, int] + + twu_str = TypeWithUnion(field="hello world") + res = await twu_str.save() + assert res.pk == twu_str.pk + twu_str_rematerialized = await TypeWithUnion.get(twu_str.pk) + assert ( + isinstance(twu_str_rematerialized.field, str) + and twu_str_rematerialized.pk == twu_str.pk + ) + + twu_int = TypeWithUnion(field=42) + await twu_int.save() + twu_int_rematerialized = await TypeWithUnion.get(twu_int.pk) + assert ( + isinstance(twu_int_rematerialized.field, int) + and twu_int_rematerialized.pk == twu_int.pk + ) From 10989427a0f388b7d269a9122b9ab0abe4e3c163 Mon Sep 17 00:00:00 2001 From: slorello89 Date: Fri, 5 Apr 2024 13:35:25 -0400 Subject: [PATCH 08/11] adding tests with uuid --- tests/test_hash_model.py | 11 +++++++++++ tests/test_json_model.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index b682a753..f7aee626 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -4,6 +4,7 @@ import dataclasses import datetime import decimal +import uuid from collections import namedtuple from typing import Dict, List, Optional, Set, Union from unittest import mock @@ -830,3 +831,13 @@ class TypeWithUnion(m.BaseHashModel): # Note - we will not be able to automatically serialize an int back to this union type, # since as far as we know from Redis this item is a string assert twu_int_rematerialized.pk == twu_int.pk + + +@py_test_mark_asyncio +async def test_type_with_uuid(): + class TypeWithUuid(HashModel): + uuid: uuid.UUID + + item = TypeWithUuid(uuid=uuid.uuid4()) + + await item.save() diff --git a/tests/test_json_model.py b/tests/test_json_model.py index ebc6d082..eefcdf84 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -4,6 +4,7 @@ import dataclasses import datetime import decimal +import uuid from collections import namedtuple from typing import Dict, List, Optional, Set, Union from unittest import mock @@ -903,3 +904,13 @@ class TypeWithUnion(m.BaseJsonModel): isinstance(twu_int_rematerialized.field, int) and twu_int_rematerialized.pk == twu_int.pk ) + + +@py_test_mark_asyncio +async def test_type_with_uuid(): + class TypeWithUuid(JsonModel): + uuid: uuid.UUID + + item = TypeWithUuid(uuid=uuid.uuid4()) + + await item.save() From 37bf65220476d7035bbf81239c56ea35e474efab Mon Sep 17 00:00:00 2001 From: slorello89 Date: Fri, 5 Apr 2024 13:53:20 -0400 Subject: [PATCH 09/11] readme fixes --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 568b937e..c8456a83 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ Check out this example of modeling customer data with Redis OM. First, we create import datetime from typing import Optional -from pydantic.v1 import EmailStr +from pydantic import EmailStr from redis_om import HashModel @@ -104,7 +104,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int - bio: Optional[str] + bio: Optional[str] = None ``` Now that we have a `Customer` model, let's use it to save customer data to Redis. @@ -113,7 +113,7 @@ Now that we have a `Customer` model, let's use it to save customer data to Redis import datetime from typing import Optional -from pydantic.v1 import EmailStr +from pydantic import EmailStr from redis_om import HashModel @@ -124,7 +124,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int - bio: Optional[str] + bio: Optional[str] = None # First, we create a new `Customer` object: @@ -168,7 +168,7 @@ For example, because we used the `EmailStr` type for the `email` field, we'll ge import datetime from typing import Optional -from pydantic.v1 import EmailStr, ValidationError +from pydantic import EmailStr, ValidationError from redis_om import HashModel @@ -179,7 +179,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int - bio: Optional[str] + bio: Optional[str] = None try: @@ -222,7 +222,7 @@ To show how this works, we'll make a small change to the `Customer` model we def import datetime from typing import Optional -from pydantic.v1 import EmailStr +from pydantic import EmailStr from redis_om import ( Field, @@ -237,7 +237,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int = Field(index=True) - bio: Optional[str] + bio: Optional[str] = None # Now, if we use this model with a Redis deployment that has the @@ -287,7 +287,7 @@ from redis_om import ( class Address(EmbeddedJsonModel): address_line_1: str - address_line_2: Optional[str] + address_line_2: Optional[str] = None city: str = Field(index=True) state: str = Field(index=True) country: str From 31d99acefe1ea6c568b1d5948a7690d754b378a1 Mon Sep 17 00:00:00 2001 From: slorello89 Date: Wed, 24 Apr 2024 14:15:27 -0400 Subject: [PATCH 10/11] addressing comments --- aredis_om/model/model.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 1b90312b..ec28fdbb 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -22,8 +22,7 @@ TypeVar, Union, ) -from typing import get_args as typing_get_args -from typing import no_type_check +from typing import get_args as typing_get_args, no_type_check from more_itertools import ichunked from redis.commands.json.path import Path @@ -1203,7 +1202,6 @@ def Field( vector_options=vector_options, **current_schema_extra, ) - # field_info._validate() return field_info @@ -1300,7 +1298,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 else: new_class.__annotations__[field_name] = ExpressionProxy # Check if this is our FieldInfo version with extended ORM metadata. - # if isinstance(field.field_info, FieldInfo): field_info = None if hasattr(field, "field_info") and isinstance(field.field_info, FieldInfo): field_info = field.field_info @@ -1371,9 +1368,7 @@ def outer_type_or_annotation(field): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): - # class RedisModel(BaseModel, abc.ABC): pk: Optional[str] = Field(default=None, primary_key=True) - # pk: Optional[str] = Field(default=None, primary_key=True) Meta = DefaultMeta @@ -1646,11 +1641,6 @@ async def save( ) -> "Model": self.check() db = self._get_db(pipeline) - - # if hasattr(self,'model_fields_set'): - # dict = {k: v for k, v in self.dict().items() if k in self.model_fields_set} - # else: - # dict = self.dict() document = jsonable_encoder(self.dict()) # TODO: Wrap any Redis response errors in a custom exception? await db.hset(self.key(), mapping=document) From 8f46dd9813f9a61684a377b1d4fc1b5323aacca4 Mon Sep 17 00:00:00 2001 From: slorello89 Date: Thu, 2 May 2024 10:21:02 -0400 Subject: [PATCH 11/11] fixing typo in NOT_IN --- aredis_om/model/model.py | 2 +- tests/test_json_model.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index ec28fdbb..31c42bdb 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -688,7 +688,7 @@ def resolve_value( elif op is Operators.NOT_IN: # TODO: Implement NOT_IN, test this... expanded_value = cls.expand_tag_value(value) - result += "-(@{field_name}):{{{expanded_value}}}".format( + result += "-(@{field_name}:{{{expanded_value}}})".format( field_name=field_name, expanded_value=expanded_value ) diff --git a/tests/test_json_model.py b/tests/test_json_model.py index eefcdf84..55e8b0fa 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -454,6 +454,15 @@ async def test_in_query(members, m): ) assert actual == [member2, member1, member3] +@py_test_mark_asyncio +async def test_not_in_query(members, m): + member1, member2, member3 = members + actual = await ( + m.Member.find(m.Member.pk >> [member2.pk, member3.pk]) + .sort_by("age") + .all() + ) + assert actual == [ member1] @py_test_mark_asyncio async def test_update_query(members, m):