Skip to content

fix issue with nested vector fields and python 3.13 issubclass changes #699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ]
pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, we definitely needed to test against 3.13

redisstack: [ "latest" ]
fail-fast: false
services:
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,7 @@ tests_sync/
# spelling cruft
*.dic

.idea
.idea

# version files
.tool-versions
18 changes: 14 additions & 4 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,14 @@ def outer_type_or_annotation(field: FieldInfo):
return field.annotation.__args__[0] # type: ignore


def _is_numeric_type(type_: Type[Any]) -> bool:
args = get_args(type_)
try:
return any(issubclass(args[0], t) for t in NUMERIC_TYPES)
except TypeError:
return False


def should_index_field(field_info: Union[FieldInfo, PydanticFieldInfo]) -> bool:
# for vector, full text search, and sortable fields, we always have to index
# We could require the user to set index=True, but that would be a breaking change
Expand Down Expand Up @@ -2004,9 +2012,7 @@ def schema_for_type(
field_info, "vector_options", None
)
try:
is_vector = vector_options and any(
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
)
is_vector = vector_options and _is_numeric_type(typ)
except IndexError:
raise RedisModelError(
f"Vector field '{name}' must be annotated as a container type"
Expand Down Expand Up @@ -2104,7 +2110,11 @@ def schema_for_type(
# a proper type, we can pull the type information from the origin of the first argument.
if not isinstance(typ, type):
type_args = typing_get_args(field_info.annotation)
typ = type_args[0].__origin__
typ = (
getattr(type_args[0], "__origin__", type_args[0])
if type_args
else typ
)

# TODO: GEO field
if is_vector and vector_options:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redis-om"
version = "1.0.1-beta"
version = "1.0.2-beta"
description = "Object mappings, and more, for Redis."
authors = ["Redis OSS <[email protected]>"]
maintainers = ["Redis OSS <[email protected]>"]
Expand All @@ -22,6 +22,7 @@ classifiers = [
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: 3.13',
'Programming Language :: Python',
]
include=[
Expand Down
6 changes: 3 additions & 3 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m):
async def test_pagination_queries(members, m):
member1, member2, member3 = members

actual = await m.Member.find(m.Member.last_name == "Brookins").page()
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page()

assert actual == [member1, member2]

actual = await m.Member.find().page(1, 1)
actual = await m.Member.find().sort_by("id").page(1, 1)

assert actual == [member2]

actual = await m.Member.find().page(0, 1)
actual = await m.Member.find().sort_by("id").page(0, 1)

assert actual == [member1]

Expand Down
45 changes: 43 additions & 2 deletions tests/test_knn_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,24 @@ class Meta:

class Member(BaseJsonModel, index=True):
name: str
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
embeddings: list[float] = Field([], vector_options=vector_field_options)
embeddings_score: Optional[float] = None

await Migrator().run()

return Member


@pytest_asyncio.fixture
async def n(key_prefix, redis):
class BaseJsonModel(JsonModel, abc.ABC):
class Meta:
global_key_prefix = key_prefix
database = redis

class Member(BaseJsonModel, index=True):
name: str
nested: list[list[float]] = Field([], vector_options=vector_field_options)
embeddings_score: Optional[float] = None

await Migrator().run()
Expand All @@ -45,7 +62,7 @@ def to_bytes(vectors: list[float]) -> bytes:
async def test_vector_field(m: Type[JsonModel]):
# Create a new instance of the Member model
vectors = [0.3 for _ in range(DIMENSIONS)]
member = m(name="seth", embeddings=[vectors])
member = m(name="seth", embeddings=vectors)

# Save the member to Redis
await member.save()
Expand All @@ -63,3 +80,27 @@ async def test_vector_field(m: Type[JsonModel]):

assert len(members) == 1
assert members[0].embeddings_score is not None


@py_test_mark_asyncio
async def test_nested_vector_field(n: Type[JsonModel]):
# Create a new instance of the Member model
vectors = [0.3 for _ in range(DIMENSIONS)]
member = n(name="seth", nested=[vectors])

# Save the member to Redis
await member.save()

knn = KNNExpression(
k=1,
vector_field=n.nested,
score_field=n.embeddings_score,
reference_vector=to_bytes(vectors),
)

query = n.find(knn=knn)

members = await query.all()

assert len(members) == 1
assert members[0].embeddings_score is not None