Skip to content

Commit 8df2fee

Browse files
denizalpaslanDeniz Alpaslanapetenchea
authored
add support for sort in collection.find function (#359)
* add support for sort in collection.find function * update sort parameter type. add SortValidationError as custom exception * Update arango/collection.py Co-authored-by: Alex Petenchea <[email protected]> * Update arango/collection.py Co-authored-by: Alex Petenchea <[email protected]> * update utils.py and collection.py to raise SortValidationError * update utils.py for build_sort_expression to accept Jsons or None * Update arango/collection.py --------- Co-authored-by: Deniz Alpaslan <[email protected]> Co-authored-by: Alex Petenchea <[email protected]>
1 parent a7ff90d commit 8df2fee

File tree

6 files changed

+86
-3
lines changed

6 files changed

+86
-3
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,6 @@ arango/version.py
124124

125125
# test results
126126
*_results.txt
127+
128+
# devcontainers
129+
.devcontainer

arango/collection.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@
5050
from arango.typings import Fields, Headers, Json, Jsons, Params
5151
from arango.utils import (
5252
build_filter_conditions,
53+
build_sort_expression,
5354
get_batches,
5455
get_doc_id,
5556
is_none_or_bool,
5657
is_none_or_int,
5758
is_none_or_str,
59+
validate_sort_parameters,
5860
)
5961

6062

@@ -753,6 +755,7 @@ def find(
753755
skip: Optional[int] = None,
754756
limit: Optional[int] = None,
755757
allow_dirty_read: bool = False,
758+
sort: Optional[Jsons] = None,
756759
) -> Result[Cursor]:
757760
"""Return all documents that match the given filters.
758761
@@ -764,23 +767,28 @@ def find(
764767
:type limit: int | None
765768
:param allow_dirty_read: Allow reads from followers in a cluster.
766769
:type allow_dirty_read: bool
770+
:param sort: Document sort parameters
771+
:type sort: Jsons | None
767772
:return: Document cursor.
768773
:rtype: arango.cursor.Cursor
769774
:raise arango.exceptions.DocumentGetError: If retrieval fails.
775+
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
770776
"""
771777
assert isinstance(filters, dict), "filters must be a dict"
772778
assert is_none_or_int(skip), "skip must be a non-negative int"
773779
assert is_none_or_int(limit), "limit must be a non-negative int"
780+
if sort:
781+
validate_sort_parameters(sort)
774782

775783
skip_val = skip if skip is not None else 0
776784
limit_val = limit if limit is not None else "null"
777785
query = f"""
778786
FOR doc IN @@collection
779787
{build_filter_conditions(filters)}
780788
LIMIT {skip_val}, {limit_val}
789+
{build_sort_expression(sort)}
781790
RETURN doc
782791
"""
783-
784792
bind_vars = {"@collection": self.name}
785793

786794
request = Request(

arango/exceptions.py

+7
Original file line numberDiff line numberDiff line change
@@ -1074,3 +1074,10 @@ class JWTRefreshError(ArangoClientError):
10741074

10751075
class JWTExpiredError(ArangoClientError):
10761076
"""JWT token has expired."""
1077+
1078+
1079+
###################################
1080+
# Parameter Validation Exceptions #
1081+
###################################
1082+
class SortValidationError(ArangoClientError):
1083+
"""Invalid sort parameters."""

arango/utils.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from contextlib import contextmanager
1212
from typing import Any, Iterator, Sequence, Union
1313

14-
from arango.exceptions import DocumentParseError
15-
from arango.typings import Json
14+
from arango.exceptions import DocumentParseError, SortValidationError
15+
from arango.typings import Json, Jsons
1616

1717

1818
@contextmanager
@@ -126,3 +126,42 @@ def build_filter_conditions(filters: Json) -> str:
126126
conditions.append(f"doc.{field} == {json.dumps(v)}")
127127

128128
return "FILTER " + " AND ".join(conditions)
129+
130+
131+
def validate_sort_parameters(sort: Sequence[Json]) -> bool:
132+
"""Validate sort parameters for an AQL query.
133+
134+
:param sort: Document sort parameters.
135+
:type sort: Sequence[Json]
136+
:return: Validation success.
137+
:rtype: bool
138+
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
139+
"""
140+
assert isinstance(sort, Sequence)
141+
for param in sort:
142+
if "sort_by" not in param or "sort_order" not in param:
143+
raise SortValidationError(
144+
"Each sort parameter must have 'sort_by' and 'sort_order'."
145+
)
146+
if param["sort_order"].upper() not in ["ASC", "DESC"]:
147+
raise SortValidationError("'sort_order' must be either 'ASC' or 'DESC'")
148+
return True
149+
150+
151+
def build_sort_expression(sort: Jsons | None) -> str:
152+
"""Build a sort condition for an AQL query.
153+
154+
:param sort: Document sort parameters.
155+
:type sort: Jsons | None
156+
:return: The complete AQL sort condition.
157+
:rtype: str
158+
"""
159+
if not sort:
160+
return ""
161+
162+
sort_chunks = []
163+
for sort_param in sort:
164+
chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}"
165+
sort_chunks.append(chunk)
166+
167+
return "SORT " + ", ".join(sort_chunks)

docs/document.rst

+6
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ Standard documents are managed via collection API wrapper:
103103
assert student['GPA'] == 3.6
104104
assert student['last'] == 'Kim'
105105

106+
# Retrieve one or more matching documents, sorted by a field.
107+
for student in students.find({'first': 'John'}, sort=[{'sort_by': 'GPA', 'sort_order': 'DESC'}]):
108+
assert student['_key'] == 'john'
109+
assert student['GPA'] == 3.6
110+
assert student['last'] == 'Kim'
111+
106112
# Retrieve a document by key.
107113
students.get('john')
108114

tests/test_document.py

+20
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,26 @@ def test_document_find(col, bad_col, docs):
11621162
# Set up test documents
11631163
col.import_bulk(docs)
11641164

1165+
# Test find with sort expression (single field)
1166+
found = list(col.find({}, sort=[{"sort_by": "text", "sort_order": "ASC"}]))
1167+
assert len(found) == 6
1168+
assert found[0]["text"] == "bar"
1169+
assert found[-1]["text"] == "foo"
1170+
1171+
# Test find with sort expression (multiple fields)
1172+
found = list(
1173+
col.find(
1174+
{},
1175+
sort=[
1176+
{"sort_by": "text", "sort_order": "ASC"},
1177+
{"sort_by": "val", "sort_order": "DESC"},
1178+
],
1179+
)
1180+
)
1181+
assert len(found) == 6
1182+
assert found[0]["val"] == 6
1183+
assert found[-1]["val"] == 1
1184+
11651185
# Test find (single match) with default options
11661186
found = list(col.find({"val": 2}))
11671187
assert len(found) == 1

0 commit comments

Comments
 (0)