|
2 | 2 | Pagination serializers determine the structure of the output that should
|
3 | 3 | be used for paginated responses.
|
4 | 4 | """
|
| 5 | +import json |
| 6 | +import operator |
| 7 | + |
5 | 8 | from base64 import b64decode, b64encode
|
6 | 9 | from collections import OrderedDict, namedtuple
|
7 | 10 | from urllib import parse
|
| 11 | +from functools import reduce |
8 | 12 |
|
9 | 13 | from django.core.paginator import InvalidPage
|
10 | 14 | from django.core.paginator import Paginator as DjangoPaginator
|
11 | 15 | from django.template import loader
|
12 | 16 | from django.utils.encoding import force_str
|
13 | 17 | from django.utils.translation import gettext_lazy as _
|
| 18 | +from django.db.models.query import Q |
14 | 19 |
|
15 | 20 | from rest_framework.compat import coreapi, coreschema
|
16 | 21 | from rest_framework.exceptions import NotFound
|
@@ -616,25 +621,42 @@ def paginate_queryset(self, queryset, request, view=None):
|
616 | 621 | else:
|
617 | 622 | (offset, reverse, current_position) = self.cursor
|
618 | 623 |
|
619 |
| - # Cursor pagination always enforces an ordering. |
620 |
| - if reverse: |
621 |
| - queryset = queryset.order_by(*_reverse_ordering(self.ordering)) |
622 |
| - else: |
623 |
| - queryset = queryset.order_by(*self.ordering) |
624 |
| - |
625 | 624 | # If we have a cursor with a fixed position then filter by that.
|
626 | 625 | if current_position is not None:
|
627 |
| - order = self.ordering[0] |
628 |
| - is_reversed = order.startswith('-') |
629 |
| - order_attr = order.lstrip('-') |
| 626 | + current_position_list = json.loads(current_position) |
630 | 627 |
|
631 |
| - # Test for: (cursor reversed) XOR (queryset reversed) |
632 |
| - if self.cursor.reverse != is_reversed: |
633 |
| - kwargs = {order_attr + '__lt': current_position} |
634 |
| - else: |
635 |
| - kwargs = {order_attr + '__gt': current_position} |
| 628 | + q_objects = {"equals": {}, "exclusionary_compare": {}} |
| 629 | + |
| 630 | + for order, position in zip(self.ordering, current_position_list): |
| 631 | + is_reversed = order.startswith("-") |
| 632 | + order_attr = order.lstrip("-") |
| 633 | + |
| 634 | + q_objects["equals"][order] = Q(**{order_attr: position}) |
| 635 | + |
| 636 | + # Test for: (cursor reversed) XOR (queryset reversed) |
| 637 | + if self.cursor.reverse != is_reversed: |
| 638 | + q_objects["exclusionary_compare"][order] = Q( |
| 639 | + **{(order_attr + "__lt"): position} |
| 640 | + ) |
| 641 | + else: |
| 642 | + q_objects["exclusionary_compare"][order] = Q( |
| 643 | + **{(order_attr + "__gt"): position} |
| 644 | + ) |
| 645 | + |
| 646 | + filter_list = [q_objects["exclusionary_compare"][self.ordering[0]]] |
| 647 | + # starting with the second field |
| 648 | + for i in range(2, len(self.ordering) + 2): |
| 649 | + # The first operands need to be equals |
| 650 | + # the last operands need to be gt |
| 651 | + equals = list(self.ordering[:i]) |
| 652 | + greater_than_q = q_objects["exclusionary_compare"][equals.pop()] |
| 653 | + sub_filters = [q_objects["equals"][e] for e in equals] |
| 654 | + sub_filters.append(greater_than_q) |
| 655 | + filter_list.append(reduce(operator.and_, sub_filters)) |
| 656 | + |
| 657 | + # This only used a single Q object previously |
| 658 | + queryset = queryset.filter(reduce(operator.or_, filter_list)) |
636 | 659 |
|
637 |
| - queryset = queryset.filter(**kwargs) |
638 | 660 |
|
639 | 661 | # If we have an offset cursor then offset the entire page by that amount.
|
640 | 662 | # We also always fetch an extra item in order to determine if there is a
|
@@ -839,7 +861,14 @@ def get_ordering(self, request, queryset, view):
|
839 | 861 | )
|
840 | 862 |
|
841 | 863 | if isinstance(ordering, str):
|
842 |
| - return (ordering,) |
| 864 | + ordering = (ordering,) |
| 865 | + |
| 866 | + pk_name = queryset.model._meta.pk.name |
| 867 | + |
| 868 | + # Always include a unique key to order by |
| 869 | + if not {f"-{pk_name}", pk_name, "pk", "-pk"} & set(ordering): |
| 870 | + ordering = ordering + (pk_name) |
| 871 | + |
843 | 872 | return tuple(ordering)
|
844 | 873 |
|
845 | 874 | def decode_cursor(self, request):
|
@@ -884,12 +913,28 @@ def encode_cursor(self, cursor):
|
884 | 913 | return replace_query_param(self.base_url, self.cursor_query_param, encoded)
|
885 | 914 |
|
886 | 915 | def _get_position_from_instance(self, instance, ordering):
|
887 |
| - field_name = ordering[0].lstrip('-') |
888 |
| - if isinstance(instance, dict): |
889 |
| - attr = instance[field_name] |
890 |
| - else: |
891 |
| - attr = getattr(instance, field_name) |
892 |
| - return str(attr) |
| 916 | + """ |
| 917 | + Overriden from the base class. |
| 918 | + This encodes the list data structure that's decoded |
| 919 | + on line 154 of this file. |
| 920 | + The old method simply return getattr(instnace, ordering[0]). |
| 921 | + This only works if the value in ordering[0] is unique. |
| 922 | + The value is json encoded here because it's an easy way to |
| 923 | + escape and serialize a list. This is then encoded as base64 |
| 924 | + by encode_cursor, which calls this function. |
| 925 | + """ |
| 926 | + fields = [] |
| 927 | + |
| 928 | + for o in ordering: |
| 929 | + field_name = o.lstrip("-") |
| 930 | + if isinstance(instance, dict): |
| 931 | + attr = instance[field_name] |
| 932 | + else: |
| 933 | + attr = getattr(instance, field_name) |
| 934 | + |
| 935 | + fields.append(str(attr)) |
| 936 | + |
| 937 | + return json.dumps(fields).encode() |
893 | 938 |
|
894 | 939 | def get_paginated_response(self, data):
|
895 | 940 | return Response(OrderedDict([
|
|
0 commit comments