Skip to content

Commit f77b288

Browse files
committed
Merge branch 'DamianHeard-issue-73' into dev
2 parents 72837f2 + aa7b48a commit f77b288

File tree

6 files changed

+174
-30
lines changed

6 files changed

+174
-30
lines changed

tests/test_aiohttp/test_aiohttpparser.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from webargs import fields
3+
from marshmallow import Schema
34

45
import asyncio
56

@@ -85,6 +86,25 @@ def echo_cookie(request):
8586
assert res.status_code == 200
8687
assert res.json == {'mycookie': 'foo'}
8788

89+
def test_parse_with_callable(app, client):
90+
91+
class MySchema(Schema):
92+
foo = fields.Field(missing=42)
93+
94+
def make_schema(req):
95+
return MySchema(context={'request': req})
96+
97+
@asyncio.coroutine
98+
def echo_parse(request):
99+
args = yield from parser.parse(make_schema, request)
100+
return jsonify(args)
101+
102+
app.router.add_route('GET', '/factory', echo_parse)
103+
104+
res = client.get('/factory')
105+
assert res.status_code == 200
106+
assert res.json == {'foo': 42}
107+
88108

89109
def test_use_args(app, client):
90110

@@ -104,6 +124,26 @@ def echo_use_args(request, args):
104124
assert res.json == {'name': 'Joe'}
105125

106126

127+
def test_use_args_with_callable(app, client):
128+
129+
class MySchema(Schema):
130+
foo = fields.Field(missing=42)
131+
132+
def make_schema(req):
133+
return MySchema(context={'request': req})
134+
135+
@asyncio.coroutine
136+
@use_args(make_schema)
137+
def echo_use_args(request, args):
138+
return jsonify(args)
139+
140+
app.router.add_route('GET', '/use_args', echo_use_args)
141+
142+
res = client.get('/use_args')
143+
assert res.status_code == 200
144+
assert res.json == {'foo': 42}
145+
146+
107147
def test_use_kwargs_on_method(app, client):
108148
class Handler:
109149

tests/test_core.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44

55
import pytest
6-
from marshmallow import Schema
6+
from marshmallow import Schema, post_load
77
from werkzeug.datastructures import MultiDict as WerkMultiDict
88

99
PY26 = sys.version_info[0] == 2 and int(sys.version_info[1]) < 7
@@ -278,7 +278,10 @@ def validate2(args):
278278
def test_required_with_custom_error(web_request):
279279
web_request.json = {}
280280
parser = MockRequestParser()
281-
args = {'foo': fields.Str(required='We need foo')}
281+
args = {'foo': fields.Str(
282+
required=True,
283+
error_messages={'required': 'We need foo'})
284+
}
282285
with pytest.raises(ValidationError) as excinfo:
283286
# Test that `validate` receives dictionary of args
284287
parser.parse(args, web_request, locations=('json', ))
@@ -441,6 +444,50 @@ def viewfunc(args):
441444
assert viewfunc() == {'username': 'foo', 'password': 'bar'}
442445

443446

447+
def test_parse_with_callable(web_request, parser):
448+
449+
web_request.json = {'foo': 42}
450+
451+
class MySchema(Schema):
452+
foo = fields.Field()
453+
454+
def make_schema(req):
455+
assert req is web_request
456+
return MySchema(context={'request': req})
457+
458+
result = parser.parse(make_schema, web_request)
459+
460+
assert result == {'foo': 42}
461+
462+
463+
def test_use_args_callable(web_request, parser):
464+
class HelloSchema(Schema):
465+
name = fields.Str()
466+
467+
class Meta(object):
468+
strict = True
469+
470+
@post_load
471+
def request_data(self, item):
472+
item['data'] = self.context['request'].data
473+
return item
474+
475+
web_request.json = {'name': 'foo'}
476+
web_request.data = 'request-data'
477+
478+
def make_schema(req):
479+
assert req is web_request
480+
return HelloSchema(context={'request': req})
481+
482+
@parser.use_args(
483+
make_schema,
484+
web_request,
485+
)
486+
def viewfunc(args):
487+
return args
488+
assert viewfunc() == {'name': 'foo', 'data': 'request-data'}
489+
490+
444491
class TestPassingSchema:
445492
class UserSchema(Schema):
446493
id = fields.Int(dump_only=True)

tests/test_pyramidparser.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from webtest import TestApp
5+
from marshmallow import Schema, post_load
56
from pyramid.config import Configurator
67

78
from webargs import fields
@@ -22,6 +23,19 @@
2223
error_messages={'validator_failed': "Houston, we've had a problem."}),
2324
}
2425

26+
27+
class HelloSchema(Schema):
28+
name = fields.String()
29+
30+
@post_load
31+
def greet(self, item):
32+
item['url'] = self.context['request'].url
33+
return item
34+
35+
class Meta(object):
36+
strict = True
37+
38+
2539
@pytest.fixture
2640
def testapp():
2741
def echo(request):
@@ -69,6 +83,10 @@ def baz(request, myvalue):
6983
def matched(request, args):
7084
return args
7185

86+
@parser.use_args(lambda req: HelloSchema(context={'request': req}))
87+
def constructor(request, args):
88+
return args
89+
7290
config = Configurator()
7391

7492
config.add_route('echo', '/echo')
@@ -81,6 +99,7 @@ def matched(request, args):
8199
config.add_route('bar', '/bar')
82100
config.add_route('baz', '/baz')
83101
config.add_route('matched', '/matched/{mymatch:\d+}')
102+
config.add_route('constructor', '/constructor')
84103

85104
config.add_view(echo, route_name='echo', renderer='json')
86105
config.add_view(echomulti, route_name='echomulti', renderer='json')
@@ -92,6 +111,7 @@ def matched(request, args):
92111
config.add_view(Bar, route_name='bar', renderer='json')
93112
config.add_view(baz, route_name='baz', renderer='json')
94113
config.add_view(matched, route_name='matched', renderer='json')
114+
config.add_view(constructor, route_name='constructor', renderer='json')
95115

96116
app = config.make_wsgi_app()
97117

@@ -153,3 +173,8 @@ def test_user_kwargs_decorator(testapp):
153173
def test_parse_matchdict(testapp):
154174
res = testapp.get('/matched/1')
155175
assert res.json == {'mymatch': 1}
176+
177+
178+
def test_use_args_callable(testapp):
179+
res = testapp.post('/constructor', {'name': 'Jean-Luc Picard'})
180+
assert res.json == {'name': 'Jean-Luc Picard', 'url': 'http://localhost/constructor'}

webargs/async.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
"""Asynchronous request parser. Compatible with Python>=3.4."""
33
import asyncio
4+
import collections
45
import inspect
56
import functools
67

@@ -15,8 +16,8 @@ class AsyncParser(core.Parser):
1516
"""
1617

1718
@asyncio.coroutine
18-
def _parse_request(self, argmap, req, locations):
19-
argdict = argmap.fields if isinstance(argmap, ma.Schema) else argmap
19+
def _parse_request(self, schema, req, locations):
20+
argdict = schema.fields
2021
parsed = {}
2122
for argname, field_obj in iteritems(argdict):
2223
parsed_value = yield from self.parse_arg(argname, field_obj, req,
@@ -35,9 +36,10 @@ def parse(self, argmap, req=None, locations=None, validate=None, force_all=False
3536
assert req is not None, 'Must pass req object'
3637
ret = None
3738
validators = core._ensure_list_of_callables(validate)
39+
schema = self._get_schema(argmap, req)
3840
try:
39-
parsed = yield from self._parse_request(argmap, req, locations)
40-
result = self.load(parsed, argmap)
41+
parsed = yield from self._parse_request(schema=schema, req=req, locations=locations)
42+
result = self.load(parsed, schema)
4143
self._validate_arguments(result.data, validators)
4244
except ma.exceptions.ValidationError as error:
4345
self._on_validation_error(error)
@@ -60,11 +62,11 @@ def use_args(self, argmap, req=None, locations=None, as_kwargs=False, validate=N
6062
Receives the same arguments as `webargs.core.Parser.use_args`.
6163
"""
6264
locations = locations or self.locations
63-
if isinstance(argmap, ma.Schema):
64-
schema = argmap
65-
else:
66-
schema = core.argmap2schema(argmap)()
6765
request_obj = req
66+
# Optimization: If argmap is passed as a dictionary, we only need
67+
# to generate a Schema once
68+
if isinstance(argmap, collections.Mapping):
69+
argmap = core.argmap2schema(argmap)()
6870

6971
def decorator(func):
7072
req_ = request_obj
@@ -78,8 +80,10 @@ def wrapper(*args, **kwargs):
7880

7981
if not req_obj:
8082
req_obj = self.get_request_from_view_args(func, args, kwargs)
81-
parsed_args = yield from self.parse(schema, req=req_obj, locations=locations,
82-
validate=validate, force_all=force_all)
83+
# NOTE: At this point, argmap may be a Schema, callable, or dict
84+
parsed_args = yield from self.parse(argmap,
85+
req=req_obj, locations=locations,
86+
validate=validate, force_all=force_all)
8387
if as_kwargs:
8488
kwargs.update(parsed_args)
8589
return func(*args, **kwargs)

webargs/core.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import unicode_literals
33

4+
import collections
45
import functools
56
import inspect
67
import logging
@@ -220,8 +221,9 @@ def parse_arg(self, name, field, req, locations=None):
220221
return value
221222
return missing
222223

223-
def _parse_request(self, argmap, req, locations):
224-
argdict = argmap.fields if isinstance(argmap, ma.Schema) else argmap
224+
def _parse_request(self, schema, req, locations):
225+
"""Return a parsed arguments dictionary for the current request."""
226+
argdict = schema.fields
225227
parsed = {}
226228
for argname, field_obj in iteritems(argdict):
227229
parsed_value = self.parse_arg(argname, field_obj, req,
@@ -262,11 +264,28 @@ def _validate_arguments(self, data, validators):
262264
msg = self.DEFAULT_VALIDATION_MESSAGE
263265
raise ValidationError(msg, data=data)
264266

267+
def _get_schema(self, argmap, req):
268+
"""Return a `marshmallow.Schema` for the given argmap and request.
269+
270+
:param argmap: Either a `marshmallow.Schema`, `dict`, or callable that returns
271+
a `marshmallow.Schema` instance.
272+
:param req: The request object being parsed.
273+
:rtype: marshmallow.Schema
274+
"""
275+
if isinstance(argmap, ma.Schema):
276+
schema = argmap
277+
elif callable(argmap):
278+
schema = argmap(req)
279+
else:
280+
schema = argmap2schema(argmap)()
281+
return schema
282+
265283
def parse(self, argmap, req=None, locations=None, validate=None, force_all=False):
266284
"""Main request parsing method.
267285
268-
:param dict argmap: Either a `marshmallow.Schema` or a `dict`
269-
of argname -> `marshmallow.fields.Field` pairs.
286+
:param argmap: Either a `marshmallow.Schema`, `dict`
287+
of argname -> `marshmallow.fields.Field` pairs, or a callable that returns
288+
a `marshmallow.Schema` instance.
270289
:param req: The request object to parse.
271290
:param tuple locations: Where on the request to search for values.
272291
Can include one or more of ``('json', 'querystring', 'form',
@@ -281,9 +300,10 @@ def parse(self, argmap, req=None, locations=None, validate=None, force_all=False
281300
assert req is not None, 'Must pass req object'
282301
ret = None
283302
validators = _ensure_list_of_callables(validate)
303+
schema = self._get_schema(argmap, req)
284304
try:
285-
parsed = self._parse_request(argmap, req, locations)
286-
result = self.load(parsed, argmap)
305+
parsed = self._parse_request(schema=schema, req=req, locations=locations)
306+
result = self.load(parsed, schema)
287307
self._validate_arguments(result.data, validators)
288308
except ma.exceptions.ValidationError as error:
289309
self._on_validation_error(error)
@@ -330,20 +350,21 @@ def use_args(self, argmap, req=None, locations=None, as_kwargs=False, validate=N
330350
def greet(args):
331351
return 'Hello ' + args['name']
332352
333-
:param dict argmap: Either a `marshmallow.Schema` or a `dict`
334-
of argname -> `marshmallow.fields.Field` pairs.
353+
:param dict argmap: Either a `marshmallow.Schema`, a `dict`
354+
of argname -> `marshmallow.fields.Field` pairs, or a callable
355+
which accepts a request and returns a `marshmallow.Schema`.
335356
:param tuple locations: Where on the request to search for values.
336357
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
337358
:param callable validate: Validation function that receives the dictionary
338359
of parsed arguments. If the function returns ``False``, the parser
339360
will raise a :exc:`ValidationError`.
340361
"""
341362
locations = locations or self.locations
342-
if isinstance(argmap, ma.Schema):
343-
schema = argmap
344-
else:
345-
schema = argmap2schema(argmap)()
346363
request_obj = req
364+
# Optimization: If argmap is passed as a dictionary, we only need
365+
# to generate a Schema once
366+
if isinstance(argmap, collections.Mapping):
367+
argmap = argmap2schema(argmap)()
347368

348369
def decorator(func):
349370
req_ = request_obj
@@ -357,8 +378,10 @@ def wrapper(*args, **kwargs):
357378

358379
if not req_obj:
359380
req_obj = self.get_request_from_view_args(func, args, kwargs)
360-
parsed_args = self.parse(schema, req=req_obj, locations=locations,
361-
validate=validate, force_all=force_all)
381+
# NOTE: At this point, argmap may be a Schema, callable, or dict
382+
parsed_args = self.parse(argmap, req=req_obj,
383+
locations=locations, validate=validate,
384+
force_all=force_all)
362385
if as_kwargs:
363386
kwargs.update(parsed_args)
364387
return func(*args, **kwargs)

webargs/pyramidparser.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ def use_args(self, argmap, req=None, locations=core.Parser.DEFAULT_LOCATIONS,
8989
Supports the *Class-based View* pattern where `request` is saved as an instance
9090
attribute on a view class.
9191
92-
:param dict argmap: Either a `marshmallow.Schema` or a `dict`
93-
of argname -> `marshmallow.fields.Field` pairs.
92+
:param dict argmap: Either a `marshmallow.Schema`, a `dict`
93+
of argname -> `marshmallow.fields.Field` pairs, or a callable
94+
which accepts a request and returns a `marshmallow.Schema`.
9495
:param req: The request object to parse. Pulled off of the view by default.
9596
:param tuple locations: Where on the request to search for values.
9697
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
@@ -101,6 +102,8 @@ def use_args(self, argmap, req=None, locations=core.Parser.DEFAULT_LOCATIONS,
101102
locations = locations or self.locations
102103
if isinstance(argmap, ma.Schema):
103104
schema = argmap
105+
elif callable(argmap):
106+
schema = None
104107
else:
105108
schema = core.argmap2schema(argmap)()
106109

@@ -112,8 +115,10 @@ def wrapper(obj, *args, **kwargs):
112115
request = req or obj.request
113116
except AttributeError: # first arg is request
114117
request = obj
115-
parsed_args = self.parse(schema, req=request, locations=locations,
116-
validate=validate, force_all=as_kwargs)
118+
119+
parsed_args = self.parse(schema or argmap(request), req=request,
120+
locations=locations, validate=validate,
121+
force_all=as_kwargs)
117122
if as_kwargs:
118123
kwargs.update(parsed_args)
119124
return func(obj, *args, **kwargs)

0 commit comments

Comments
 (0)