Skip to content

Commit 3ee55cd

Browse files
committed
Fix broken pagination
1 parent 2375f6c commit 3ee55cd

File tree

1 file changed

+67
-22
lines changed

1 file changed

+67
-22
lines changed

rest_framework/pagination.py

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,20 @@
22
Pagination serializers determine the structure of the output that should
33
be used for paginated responses.
44
"""
5+
import json
6+
import operator
7+
58
from base64 import b64decode, b64encode
69
from collections import OrderedDict, namedtuple
710
from urllib import parse
11+
from functools import reduce
812

913
from django.core.paginator import InvalidPage
1014
from django.core.paginator import Paginator as DjangoPaginator
1115
from django.template import loader
1216
from django.utils.encoding import force_str
1317
from django.utils.translation import gettext_lazy as _
18+
from django.db.models.query import Q
1419

1520
from rest_framework.compat import coreapi, coreschema
1621
from rest_framework.exceptions import NotFound
@@ -616,25 +621,42 @@ def paginate_queryset(self, queryset, request, view=None):
616621
else:
617622
(offset, reverse, current_position) = self.cursor
618623

619-
# Cursor pagination always enforces an ordering.
620-
if reverse:
621-
queryset = queryset.order_by(*_reverse_ordering(self.ordering))
622-
else:
623-
queryset = queryset.order_by(*self.ordering)
624-
625624
# If we have a cursor with a fixed position then filter by that.
626625
if current_position is not None:
627-
order = self.ordering[0]
628-
is_reversed = order.startswith('-')
629-
order_attr = order.lstrip('-')
626+
current_position_list = json.loads(current_position)
630627

631-
# Test for: (cursor reversed) XOR (queryset reversed)
632-
if self.cursor.reverse != is_reversed:
633-
kwargs = {order_attr + '__lt': current_position}
634-
else:
635-
kwargs = {order_attr + '__gt': current_position}
628+
q_objects = {"equals": {}, "exclusionary_compare": {}}
629+
630+
for order, position in zip(self.ordering, current_position_list):
631+
is_reversed = order.startswith("-")
632+
order_attr = order.lstrip("-")
633+
634+
q_objects["equals"][order] = Q(**{order_attr: position})
635+
636+
# Test for: (cursor reversed) XOR (queryset reversed)
637+
if self.cursor.reverse != is_reversed:
638+
q_objects["exclusionary_compare"][order] = Q(
639+
**{(order_attr + "__lt"): position}
640+
)
641+
else:
642+
q_objects["exclusionary_compare"][order] = Q(
643+
**{(order_attr + "__gt"): position}
644+
)
645+
646+
filter_list = [q_objects["exclusionary_compare"][self.ordering[0]]]
647+
# starting with the second field
648+
for i in range(2, len(self.ordering) + 2):
649+
# The first operands need to be equals
650+
# the last operands need to be gt
651+
equals = list(self.ordering[:i])
652+
greater_than_q = q_objects["exclusionary_compare"][equals.pop()]
653+
sub_filters = [q_objects["equals"][e] for e in equals]
654+
sub_filters.append(greater_than_q)
655+
filter_list.append(reduce(operator.and_, sub_filters))
656+
657+
# This only used a single Q object previously
658+
queryset = queryset.filter(reduce(operator.or_, filter_list))
636659

637-
queryset = queryset.filter(**kwargs)
638660

639661
# If we have an offset cursor then offset the entire page by that amount.
640662
# We also always fetch an extra item in order to determine if there is a
@@ -839,7 +861,14 @@ def get_ordering(self, request, queryset, view):
839861
)
840862

841863
if isinstance(ordering, str):
842-
return (ordering,)
864+
ordering = (ordering,)
865+
866+
pk_name = queryset.model._meta.pk.name
867+
868+
# Always include a unique key to order by
869+
if not {f"-{pk_name}", pk_name, "pk", "-pk"} & set(ordering):
870+
ordering = ordering + (pk_name)
871+
843872
return tuple(ordering)
844873

845874
def decode_cursor(self, request):
@@ -884,12 +913,28 @@ def encode_cursor(self, cursor):
884913
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
885914

886915
def _get_position_from_instance(self, instance, ordering):
887-
field_name = ordering[0].lstrip('-')
888-
if isinstance(instance, dict):
889-
attr = instance[field_name]
890-
else:
891-
attr = getattr(instance, field_name)
892-
return str(attr)
916+
"""
917+
Overriden from the base class.
918+
This encodes the list data structure that's decoded
919+
on line 154 of this file.
920+
The old method simply return getattr(instnace, ordering[0]).
921+
This only works if the value in ordering[0] is unique.
922+
The value is json encoded here because it's an easy way to
923+
escape and serialize a list. This is then encoded as base64
924+
by encode_cursor, which calls this function.
925+
"""
926+
fields = []
927+
928+
for o in ordering:
929+
field_name = o.lstrip("-")
930+
if isinstance(instance, dict):
931+
attr = instance[field_name]
932+
else:
933+
attr = getattr(instance, field_name)
934+
935+
fields.append(str(attr))
936+
937+
return json.dumps(fields).encode()
893938

894939
def get_paginated_response(self, data):
895940
return Response(OrderedDict([

0 commit comments

Comments
 (0)