Skip to content

Commit ff1f016

Browse files
committed
fix: fix wrong total_count when using distinct on m2m/o2m relationships
Fix #792
1 parent 407d80a commit ff1f016

File tree

2 files changed

+160
-2
lines changed

2 files changed

+160
-2
lines changed

strawberry_django/pagination.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def apply_window_pagination(
249249
def remove_window_pagination(queryset: _QS) -> _QS:
250250
"""Remove pagination window functions from a queryset.
251251
252-
Utility function to remove the pagination `WHERE` clause added by
253-
the `apply_window_pagination` function.
252+
Utility function to remove the pagination `WHERE` clause and annotations
253+
added by the `apply_window_pagination` function.
254254
255255
Args:
256256
----
@@ -263,6 +263,11 @@ def remove_window_pagination(queryset: _QS) -> _QS:
263263
for child in queryset.query.where.children
264264
if (not hasattr(child, "lhs") or not isinstance(child.lhs, _PaginationWindow))
265265
]
266+
queryset.query.annotations = { # type: ignore
267+
key: value
268+
for key, value in queryset.query.annotations.items()
269+
if not isinstance(value, _PaginationWindow)
270+
}
266271
return queryset
267272

268273

@@ -278,6 +283,13 @@ def get_total_count(queryset: QuerySet) -> int:
278283
results = queryset._result_cache # type: ignore
279284

280285
if results:
286+
# If the queryset has DISTINCT enabled, the _strawberry_total_count
287+
# annotation won't be accurate because window functions are evaluated
288+
# before DISTINCT in SQL. Fall back to queryset.count() instead.
289+
if queryset.query.distinct:
290+
queryset = remove_window_pagination(queryset)
291+
return queryset.count()
292+
281293
try:
282294
return results[0]._strawberry_total_count
283295
except AttributeError:

tests/relay/test_cursor_pagination.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,6 +1419,152 @@ def test_invalid_cursor(cursor, test_objects):
14191419
assert result.errors[0].message == "Invalid cursor"
14201420

14211421

1422+
@pytest.mark.django_db(transaction=True)
1423+
def test_connection_distinct_with_m2m_filters():
1424+
from tests import models
1425+
1426+
@strawberry_django.filter_type(models.FruitType, lookups=True)
1427+
class FruitTypeFilter:
1428+
name: strawberry.auto
1429+
1430+
@strawberry_django.filter_type(models.Fruit, lookups=True)
1431+
class FruitFilter:
1432+
name: strawberry.auto
1433+
types: FruitTypeFilter | None
1434+
DISTINCT: bool | None
1435+
1436+
@strawberry_django.type(models.FruitType)
1437+
class FruitTypeGQL(Node):
1438+
name: strawberry.auto
1439+
1440+
@strawberry_django.type(models.Fruit, filters=FruitFilter)
1441+
class FruitGQL(Node):
1442+
name: strawberry.auto
1443+
types: list[FruitTypeGQL]
1444+
1445+
@strawberry.type
1446+
class FruitQuery:
1447+
fruits: DjangoCursorConnection[FruitGQL] = strawberry_django.connection()
1448+
1449+
fruit_schema = strawberry.Schema(
1450+
query=FruitQuery, extensions=[DjangoOptimizerExtension()]
1451+
)
1452+
1453+
ft1 = models.FruitType.objects.create(name="tropical")
1454+
ft2 = models.FruitType.objects.create(name="citrus")
1455+
1456+
banana = models.Fruit.objects.create(name="banana")
1457+
banana.types.add(ft1, ft2)
1458+
1459+
apple = models.Fruit.objects.create(name="apple")
1460+
apple.types.add(ft1)
1461+
1462+
orange = models.Fruit.objects.create(name="orange")
1463+
orange.types.add(ft2)
1464+
1465+
query = """
1466+
query TestQuery {
1467+
fruits(filters: {
1468+
types: { name: { inList: ["tropical", "citrus"] } }
1469+
DISTINCT: true
1470+
}) {
1471+
edges {
1472+
node {
1473+
id
1474+
name
1475+
}
1476+
}
1477+
totalCount
1478+
}
1479+
}
1480+
"""
1481+
result = fruit_schema.execute_sync(query)
1482+
assert not result.errors
1483+
assert result.data
1484+
1485+
edges = result.data["fruits"]["edges"]
1486+
total_count = result.data["fruits"]["totalCount"]
1487+
1488+
assert len(edges) == 3
1489+
assert total_count == 3
1490+
1491+
names = {edge["node"]["name"] for edge in edges}
1492+
assert names == {"banana", "apple", "orange"}
1493+
1494+
1495+
@pytest.mark.django_db(transaction=True)
1496+
def test_connection_distinct_with_pagination():
1497+
from tests import models
1498+
1499+
@strawberry_django.filter_type(models.FruitType, lookups=True)
1500+
class FruitTypeFilter:
1501+
name: strawberry.auto
1502+
1503+
@strawberry_django.filter_type(models.Fruit, lookups=True)
1504+
class FruitFilter:
1505+
name: strawberry.auto
1506+
types: FruitTypeFilter | None
1507+
DISTINCT: bool | None
1508+
1509+
@strawberry_django.type(models.FruitType)
1510+
class FruitTypeGQL(Node):
1511+
name: strawberry.auto
1512+
1513+
@strawberry_django.type(models.Fruit, filters=FruitFilter)
1514+
class FruitGQL(Node):
1515+
name: strawberry.auto
1516+
types: list[FruitTypeGQL]
1517+
1518+
@strawberry.type
1519+
class FruitQuery:
1520+
fruits: DjangoCursorConnection[FruitGQL] = strawberry_django.connection()
1521+
1522+
fruit_schema = strawberry.Schema(
1523+
query=FruitQuery, extensions=[DjangoOptimizerExtension()]
1524+
)
1525+
1526+
ft1 = models.FruitType.objects.create(name="tropical")
1527+
ft2 = models.FruitType.objects.create(name="citrus")
1528+
1529+
for i in range(5):
1530+
fruit = models.Fruit.objects.create(name=f"fruit_{i}")
1531+
fruit.types.add(ft1, ft2)
1532+
1533+
query = """
1534+
query TestQuery {
1535+
fruits(
1536+
first: 2
1537+
filters: {
1538+
types: { name: { inList: ["tropical", "citrus"] } }
1539+
DISTINCT: true
1540+
}
1541+
) {
1542+
edges {
1543+
node {
1544+
id
1545+
name
1546+
}
1547+
}
1548+
totalCount
1549+
pageInfo {
1550+
hasNextPage
1551+
}
1552+
}
1553+
}
1554+
"""
1555+
result = fruit_schema.execute_sync(query)
1556+
assert not result.errors
1557+
assert result.data
1558+
1559+
edges = result.data["fruits"]["edges"]
1560+
total_count = result.data["fruits"]["totalCount"]
1561+
has_next = result.data["fruits"]["pageInfo"]["hasNextPage"]
1562+
1563+
assert len(edges) == 2
1564+
assert total_count == 5
1565+
assert has_next is True
1566+
1567+
14221568
@pytest.mark.django_db(transaction=True)
14231569
@pytest.mark.parametrize(
14241570
("first", "last", "error_message"),

0 commit comments

Comments
 (0)