From 487697fa8ff67c23c083318849a6fdd86ece4612 Mon Sep 17 00:00:00 2001 From: Val Tikhonov Date: Fri, 7 Sep 2018 12:01:14 +0200 Subject: [PATCH] Add filters --- graphene_sqlalchemy/__init__.py | 3 +- graphene_sqlalchemy/fields.py | 19 +++++++ graphene_sqlalchemy/filters.py | 71 +++++++++++++++++++++++++ graphene_sqlalchemy/tests/test_query.py | 35 +++++++++++- 4 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 graphene_sqlalchemy/filters.py diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 810150f0..97c09574 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ from .types import SQLAlchemyObjectType -from .fields import SQLAlchemyConnectionField +from .fields import SQLAlchemyConnectionField, FilterableConnectionField from .utils import get_query, get_session __version__ = "2.1.0" @@ -8,6 +8,7 @@ "__version__", "SQLAlchemyObjectType", "SQLAlchemyConnectionField", + "FilterableConnectionField", "get_query", "get_session", ] diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index bf3522b4..64c00a95 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -6,6 +6,7 @@ from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice +from .filters import filter_class_for_module, Filter from .utils import get_query, sort_argument_for_model @@ -94,6 +95,24 @@ def __init__(self, type, *args, **kwargs): super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) +class FilterableConnectionField(SQLAlchemyConnectionField): + def __init__(self, type, *args, **kwargs): + if 'filter' not in kwargs and issubclass(type, Connection): + model = type.Edge.node._type._meta.model + kwargs.setdefault('filter', filter_class_for_module(model)) + elif "filter" in kwargs and kwargs["filter"] is None: + del kwargs["filter"] + super(FilterableConnectionField, self).__init__(type, *args, **kwargs) + + @classmethod + def get_query(cls, model, info, filter=None, **kwargs): + query = super(FilterableConnectionField, cls).get_query(model, info, **kwargs) + if filter: + for k, v in filter.items(): + query = Filter.add_filter_to_query(query, model, k, v) + return query + + __connectionFactory = UnsortedSQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py new file mode 100644 index 00000000..4dc83e2f --- /dev/null +++ b/graphene_sqlalchemy/filters.py @@ -0,0 +1,71 @@ +import graphene + +from collections import OrderedDict +from graphene import Argument, Field +from sqlalchemy import inspect + +# Cache for the generated classes, to avoid name clash +_INPUT_CACHE = {} +_INPUT_FIELDS_CACHE = {} + + +class Filter: + @staticmethod + def add_filter_to_query(query, model, field, value): + [(operator, value)] = value.items() + if operator == 'eq': + query = query.filter(getattr(model, field) == value) + elif operator == 'ne': + query = query.filter(getattr(model, field) == value) + elif operator == 'lt': + query = query.filter(getattr(model, field) < value) + elif operator == 'gt': + query = query.filter(getattr(model, field) > value) + elif operator == 'like': + query = query.filter(getattr(model, field).like(value)) + return query + + +def filter_class_for_module(cls): + name = cls.__name__ + "InputFilter" + if name in _INPUT_CACHE: + return Argument(_INPUT_CACHE[name]) + + class InputFilterBase: + pass + + fields = OrderedDict() + for column in inspect(cls).columns.values(): + maybe_field = create_input_filter_field(column) + if maybe_field: + fields[column.name] = maybe_field + input_class = type(name, (InputFilterBase, graphene.InputObjectType), {}) + input_class._meta.fields.update(fields) + _INPUT_CACHE[name] = input_class + return Argument(input_class) + + +def create_input_filter_field(column): + from .converter import convert_sqlalchemy_type + graphene_type = convert_sqlalchemy_type(column.type, column) + if graphene_type.__class__ == Field: # TODO enum not supported + return None + name = str(graphene_type.__class__) + 'Filter' + + if name in _INPUT_FIELDS_CACHE: + return Field(_INPUT_FIELDS_CACHE[name]) + + field_class = Filter + fields = OrderedDict() + fields['eq'] = Field(graphene_type.__class__, description='Field should be equal to given value') + fields['ne'] = Field(graphene_type.__class__, description='Field should not be equal to given value') + fields['lt'] = Field(graphene_type.__class__, description='Field should be less then given value') + fields['gt'] = Field(graphene_type.__class__, description='Field should be great then given value') + fields['like'] = Field(graphene_type.__class__, description='Field should have a pattern of given value') + # TODO construct operators based on __class__ + # TODO complex filter support: OR + + field_class = type(name, (field_class, graphene.InputObjectType), {}) + field_class._meta.fields.update(fields) + _INPUT_FIELDS_CACHE[name] = field_class + return Field(field_class) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index f8bc8403..cd795c26 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -6,7 +6,7 @@ from graphene.relay import Connection, Node from ..registry import reset_global_registry -from ..fields import SQLAlchemyConnectionField +from ..fields import SQLAlchemyConnectionField, FilterableConnectionField from ..types import SQLAlchemyObjectType from ..utils import sort_argument_for_model, sort_enum_for_model from .models import Article, Base, Editor, Pet, Reporter @@ -484,3 +484,36 @@ def makeNodes(nodeList): node["node"]["name"] for node in expectedNoSort[key]["edges"] ) + +def test_filter(session): + sort_setup(session) + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + + class PetConnection(Connection): + class Meta: + node = PetNode + + class Query(graphene.ObjectType): + pets = FilterableConnectionField(PetConnection) + + only_lassie_query = """ + query { + pets(filter: {name: {eq: "Lassie"}}) { + edges { + node { + name + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = schema.execute(only_lassie_query, context_value={"session": session}) + assert len(result.data['pets']['edges']) == 1 + assert result.data['pets']['edges'][0]['node']['name'] == 'Lassie' + +