Skip to content

Commit 487697f

Browse files
committed
Add filters
1 parent 33d5b74 commit 487697f

File tree

4 files changed

+126
-2
lines changed

4 files changed

+126
-2
lines changed

graphene_sqlalchemy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .types import SQLAlchemyObjectType
2-
from .fields import SQLAlchemyConnectionField
2+
from .fields import SQLAlchemyConnectionField, FilterableConnectionField
33
from .utils import get_query, get_session
44

55
__version__ = "2.1.0"
@@ -8,6 +8,7 @@
88
"__version__",
99
"SQLAlchemyObjectType",
1010
"SQLAlchemyConnectionField",
11+
"FilterableConnectionField",
1112
"get_query",
1213
"get_session",
1314
]

graphene_sqlalchemy/fields.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from graphene.relay.connection import PageInfo
77
from graphql_relay.connection.arrayconnection import connection_from_list_slice
88

9+
from .filters import filter_class_for_module, Filter
910
from .utils import get_query, sort_argument_for_model
1011

1112

@@ -94,6 +95,24 @@ def __init__(self, type, *args, **kwargs):
9495
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
9596

9697

98+
class FilterableConnectionField(SQLAlchemyConnectionField):
99+
def __init__(self, type, *args, **kwargs):
100+
if 'filter' not in kwargs and issubclass(type, Connection):
101+
model = type.Edge.node._type._meta.model
102+
kwargs.setdefault('filter', filter_class_for_module(model))
103+
elif "filter" in kwargs and kwargs["filter"] is None:
104+
del kwargs["filter"]
105+
super(FilterableConnectionField, self).__init__(type, *args, **kwargs)
106+
107+
@classmethod
108+
def get_query(cls, model, info, filter=None, **kwargs):
109+
query = super(FilterableConnectionField, cls).get_query(model, info, **kwargs)
110+
if filter:
111+
for k, v in filter.items():
112+
query = Filter.add_filter_to_query(query, model, k, v)
113+
return query
114+
115+
97116
__connectionFactory = UnsortedSQLAlchemyConnectionField
98117

99118

graphene_sqlalchemy/filters.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import graphene
2+
3+
from collections import OrderedDict
4+
from graphene import Argument, Field
5+
from sqlalchemy import inspect
6+
7+
# Cache for the generated classes, to avoid name clash
8+
_INPUT_CACHE = {}
9+
_INPUT_FIELDS_CACHE = {}
10+
11+
12+
class Filter:
13+
@staticmethod
14+
def add_filter_to_query(query, model, field, value):
15+
[(operator, value)] = value.items()
16+
if operator == 'eq':
17+
query = query.filter(getattr(model, field) == value)
18+
elif operator == 'ne':
19+
query = query.filter(getattr(model, field) == value)
20+
elif operator == 'lt':
21+
query = query.filter(getattr(model, field) < value)
22+
elif operator == 'gt':
23+
query = query.filter(getattr(model, field) > value)
24+
elif operator == 'like':
25+
query = query.filter(getattr(model, field).like(value))
26+
return query
27+
28+
29+
def filter_class_for_module(cls):
30+
name = cls.__name__ + "InputFilter"
31+
if name in _INPUT_CACHE:
32+
return Argument(_INPUT_CACHE[name])
33+
34+
class InputFilterBase:
35+
pass
36+
37+
fields = OrderedDict()
38+
for column in inspect(cls).columns.values():
39+
maybe_field = create_input_filter_field(column)
40+
if maybe_field:
41+
fields[column.name] = maybe_field
42+
input_class = type(name, (InputFilterBase, graphene.InputObjectType), {})
43+
input_class._meta.fields.update(fields)
44+
_INPUT_CACHE[name] = input_class
45+
return Argument(input_class)
46+
47+
48+
def create_input_filter_field(column):
49+
from .converter import convert_sqlalchemy_type
50+
graphene_type = convert_sqlalchemy_type(column.type, column)
51+
if graphene_type.__class__ == Field: # TODO enum not supported
52+
return None
53+
name = str(graphene_type.__class__) + 'Filter'
54+
55+
if name in _INPUT_FIELDS_CACHE:
56+
return Field(_INPUT_FIELDS_CACHE[name])
57+
58+
field_class = Filter
59+
fields = OrderedDict()
60+
fields['eq'] = Field(graphene_type.__class__, description='Field should be equal to given value')
61+
fields['ne'] = Field(graphene_type.__class__, description='Field should not be equal to given value')
62+
fields['lt'] = Field(graphene_type.__class__, description='Field should be less then given value')
63+
fields['gt'] = Field(graphene_type.__class__, description='Field should be great then given value')
64+
fields['like'] = Field(graphene_type.__class__, description='Field should have a pattern of given value')
65+
# TODO construct operators based on __class__
66+
# TODO complex filter support: OR
67+
68+
field_class = type(name, (field_class, graphene.InputObjectType), {})
69+
field_class._meta.fields.update(fields)
70+
_INPUT_FIELDS_CACHE[name] = field_class
71+
return Field(field_class)

graphene_sqlalchemy/tests/test_query.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from graphene.relay import Connection, Node
77

88
from ..registry import reset_global_registry
9-
from ..fields import SQLAlchemyConnectionField
9+
from ..fields import SQLAlchemyConnectionField, FilterableConnectionField
1010
from ..types import SQLAlchemyObjectType
1111
from ..utils import sort_argument_for_model, sort_enum_for_model
1212
from .models import Article, Base, Editor, Pet, Reporter
@@ -484,3 +484,36 @@ def makeNodes(nodeList):
484484
node["node"]["name"] for node in expectedNoSort[key]["edges"]
485485
)
486486

487+
488+
def test_filter(session):
489+
sort_setup(session)
490+
491+
class PetNode(SQLAlchemyObjectType):
492+
class Meta:
493+
model = Pet
494+
interfaces = (Node,)
495+
496+
class PetConnection(Connection):
497+
class Meta:
498+
node = PetNode
499+
500+
class Query(graphene.ObjectType):
501+
pets = FilterableConnectionField(PetConnection)
502+
503+
only_lassie_query = """
504+
query {
505+
pets(filter: {name: {eq: "Lassie"}}) {
506+
edges {
507+
node {
508+
name
509+
}
510+
}
511+
}
512+
}
513+
"""
514+
schema = graphene.Schema(query=Query)
515+
result = schema.execute(only_lassie_query, context_value={"session": session})
516+
assert len(result.data['pets']['edges']) == 1
517+
assert result.data['pets']['edges'][0]['node']['name'] == 'Lassie'
518+
519+

0 commit comments

Comments
 (0)