Skip to content

Add Validation to HTTP Inbound #2978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* Spec for a {@link SimpleMessageListenerContainer}.
*
* @author Gary Russell
* @author Artem Bilan
*
* @since 5.0
*
*/
Expand Down Expand Up @@ -108,10 +110,22 @@ public SimpleMessageListenerContainerSpec receiveTimeout(long receiveTimeout) {
/**
* @param txSize the txSize.
* @return the spec.
* @see SimpleMessageListenerContainer#setTxSize(int)
* @see SimpleMessageListenerContainer#setBatchSize(int)
* @deprecated since 5.2 in favor of {@link #batchSize(int)}
*/
public SimpleMessageListenerContainerSpec txSize(int txSize) {
this.listenerContainer.setTxSize(txSize);
return batchSize(txSize);
}

/**
* The batch size to use.
* @param batchSize the batchSize.
* @return the spec.
* @see SimpleMessageListenerContainer#setBatchSize(int)
* @since 5.2
*/
public SimpleMessageListenerContainerSpec batchSize(int batchSize) {
this.listenerContainer.setBatchSize(batchSize);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ protected void doParse(Element element, ParserContext parserContext, BeanDefinit
List<Element> headerElements = DomUtils.getChildElementsByTagName(element, "header");

if (!CollectionUtils.isEmpty(headerElements)) {
ManagedMap<String, Object> headerElementsMap = new ManagedMap<String, Object>();
ManagedMap<String, Object> headerElementsMap = new ManagedMap<>();
for (Element headerElement : headerElements) {
String name = headerElement.getAttribute(NAME_ATTRIBUTE);
BeanDefinition headerExpressionDef =
IntegrationNamespaceUtils.createExpressionDefIfAttributeDefined(IntegrationNamespaceUtils.EXPRESSION_ATTRIBUTE,
headerElement);
IntegrationNamespaceUtils
.createExpressionDefIfAttributeDefined(IntegrationNamespaceUtils.EXPRESSION_ATTRIBUTE,
headerElement);
if (headerExpressionDef != null) {
headerElementsMap.put(name, headerExpressionDef);
}
Expand All @@ -133,8 +134,8 @@ protected void doParse(Element element, ParserContext parserContext, BeanDefinit
}

BeanDefinition expressionDef =
IntegrationNamespaceUtils.createExpressionDefinitionFromValueOrExpression("view-name", "view-expression",
parserContext, element, false);
IntegrationNamespaceUtils.createExpressionDefinitionFromValueOrExpression("view-name",
"view-expression", parserContext, element, false);
if (expressionDef != null) {
builder.addPropertyValue("viewExpression", expressionDef);
}
Expand All @@ -154,13 +155,16 @@ protected void doParse(Element element, ParserContext parserContext, BeanDefinit

if (StringUtils.hasText(headerMapper)) {
if (hasMappedRequestHeaders || hasMappedResponseHeaders) {
parserContext.getReaderContext().error("Neither 'mapped-request-headers' or 'mapped-response-headers' " +
"attributes are allowed when a 'header-mapper' has been specified.", parserContext.extractSource(element));
parserContext.getReaderContext()
.error("Neither 'mapped-request-headers' or 'mapped-response-headers' " +
"attributes are allowed when a 'header-mapper' has been specified.",
parserContext.extractSource(element));
}
builder.addPropertyReference("headerMapper", headerMapper);
}
else {
BeanDefinitionBuilder headerMapperBuilder = BeanDefinitionBuilder.genericBeanDefinition(DefaultHttpHeaderMapper.class);
BeanDefinitionBuilder headerMapperBuilder =
BeanDefinitionBuilder.genericBeanDefinition(DefaultHttpHeaderMapper.class);
headerMapperBuilder.setFactoryMethod("inboundMapper");

if (hasMappedRequestHeaders) {
Expand All @@ -181,7 +185,7 @@ protected void doParse(Element element, ParserContext parserContext, BeanDefinit
if (crossOriginElement != null) {
BeanDefinitionBuilder crossOriginBuilder =
BeanDefinitionBuilder.genericBeanDefinition(CrossOrigin.class);
String[] attributes = {"origin", "allowed-headers", "exposed-headers", "max-age", "method"};
String[] attributes = { "origin", "allowed-headers", "exposed-headers", "max-age", "method" };
for (String crossOriginAttribute : attributes) {
IntegrationNamespaceUtils.setValueIfAttributeDefined(crossOriginBuilder, crossOriginElement,
crossOriginAttribute);
Expand All @@ -191,7 +195,8 @@ protected void doParse(Element element, ParserContext parserContext, BeanDefinit
builder.addPropertyValue("crossOrigin", crossOriginBuilder.getBeanDefinition());
}

IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element, "request-payload-type", "requestPayloadTypeClass");
IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element,
"request-payload-type", "requestPayloadTypeClass");

BeanDefinition statusCodeExpressionDef =
IntegrationNamespaceUtils.createExpressionDefIfAttributeDefined("status-code-expression", element);
Expand All @@ -205,14 +210,16 @@ protected void doParse(Element element, ParserContext parserContext, BeanDefinit

IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element, IntegrationNamespaceUtils.AUTO_STARTUP);
IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element, IntegrationNamespaceUtils.PHASE);
IntegrationNamespaceUtils.setReferenceIfAttributeDefined(builder, element, "validator");
}

private String getInputChannelAttributeName() {
return this.expectReply ? "request-channel" : "channel";
}

private BeanDefinition createRequestMapping(Element element) {
BeanDefinitionBuilder requestMappingDefBuilder = BeanDefinitionBuilder.genericBeanDefinition(RequestMapping.class);
BeanDefinitionBuilder requestMappingDefBuilder =
BeanDefinitionBuilder.genericBeanDefinition(RequestMapping.class);

String methods = element.getAttribute("supported-methods");
if (StringUtils.hasText(methods)) {
Expand All @@ -224,7 +231,7 @@ private BeanDefinition createRequestMapping(Element element) {
Element requestMappingElement = DomUtils.getChildElementByTagName(element, "request-mapping");

if (requestMappingElement != null) {
for (String requestMappingAttribute : new String[]{"params", "headers", "consumes", "produces"}) {
for (String requestMappingAttribute : new String[] { "params", "headers", "consumes", "produces" }) {
IntegrationNamespaceUtils.setValueIfAttributeDefined(requestMappingDefBuilder, requestMappingElement,
requestMappingAttribute);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.integration.http.support.DefaultHttpHeaderMapper;
import org.springframework.integration.mapping.HeaderMapper;
import org.springframework.util.Assert;
import org.springframework.validation.Validator;
import org.springframework.web.bind.annotation.RequestMethod;

/**
Expand All @@ -45,7 +46,8 @@
*
* @since 5.0
*/
public abstract class HttpInboundEndpointSupportSpec<S extends HttpInboundEndpointSupportSpec<S, E>, E extends BaseHttpInboundEndpoint>
public abstract class HttpInboundEndpointSupportSpec<S extends HttpInboundEndpointSupportSpec<S, E>,
E extends BaseHttpInboundEndpoint>
extends MessagingGatewaySpec<S, E>
implements ComponentsRegistration {

Expand Down Expand Up @@ -93,7 +95,7 @@ public S crossOrigin(Consumer<CrossOriginSpec> crossOrigin) {
* Specify a SpEL expression to evaluate in order to generate the Message payload.
* @param payloadExpression The payload expression.
* @return the spec
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setPayloadExpression(Expression)
* @see BaseHttpInboundEndpoint#setPayloadExpression(Expression)
*/
public S payloadExpression(String payloadExpression) {
return payloadExpression(PARSER.parseExpression(payloadExpression));
Expand All @@ -103,7 +105,7 @@ public S payloadExpression(String payloadExpression) {
* Specify a SpEL expression to evaluate in order to generate the Message payload.
* @param payloadExpression The payload expression.
* @return the spec
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setPayloadExpression(Expression)
* @see BaseHttpInboundEndpoint#setPayloadExpression(Expression)
*/
public S payloadExpression(Expression payloadExpression) {
this.target.setPayloadExpression(payloadExpression);
Expand All @@ -115,7 +117,7 @@ public S payloadExpression(Expression payloadExpression) {
* @param payloadFunction The payload {@link Function}.
* @param <P> the expected HTTP request body type.
* @return the spec
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setPayloadExpression(Expression)
* @see BaseHttpInboundEndpoint#setPayloadExpression(Expression)
*/
public <P> S payloadFunction(Function<HttpEntity<P>, ?> payloadFunction) {
return payloadExpression(new FunctionExpression<>(payloadFunction));
Expand All @@ -125,7 +127,7 @@ public <P> S payloadFunction(Function<HttpEntity<P>, ?> payloadFunction) {
* Specify a Map of SpEL expressions to evaluate in order to generate the Message headers.
* @param expressions The {@link Map} of SpEL expressions for headers.
* @return the spec
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setHeaderExpressions(Map)
* @see BaseHttpInboundEndpoint#setHeaderExpressions(Map)
*/
public S headerExpressions(Map<String, Expression> expressions) {
Assert.notNull(expressions, "'headerExpressions' must not be null");
Expand All @@ -139,7 +141,7 @@ public S headerExpressions(Map<String, Expression> expressions) {
* @param header the header name to populate.
* @param expression the SpEL expression for the header.
* @return the spec
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setHeaderExpressions(Map)
* @see BaseHttpInboundEndpoint#setHeaderExpressions(Map)
*/
public S headerExpression(String header, String expression) {
return headerExpression(header, PARSER.parseExpression(expression));
Expand All @@ -150,7 +152,7 @@ public S headerExpression(String header, String expression) {
* @param header the header name to populate.
* @param expression the SpEL expression for the header.
* @return the spec
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setHeaderExpressions(Map)
* @see BaseHttpInboundEndpoint#setHeaderExpressions(Map)
*/
public S headerExpression(String header, Expression expression) {
this.headerExpressions.put(header, expression);
Expand All @@ -163,7 +165,7 @@ public S headerExpression(String header, Expression expression) {
* @param headerFunction the function to evaluate the header value against {@link HttpEntity}.
* @param <P> the expected HTTP body type.
* @return the current Spec.
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setHeaderExpressions(Map)
* @see BaseHttpInboundEndpoint#setHeaderExpressions(Map)
*/
public <P> S headerFunction(String header, Function<HttpEntity<P>, ?> headerFunction) {
return headerExpression(header, new FunctionExpression<>(headerFunction));
Expand Down Expand Up @@ -251,7 +253,7 @@ public S extractReplyPayload(boolean extractReplyPayload) {
* the default '200 OK' or '500 Internal Server Error' for a timeout.
* @param statusCodeExpression The status code Expression.
* @return the current Spec.
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setStatusCodeExpression(Expression)
* @see BaseHttpInboundEndpoint#setStatusCodeExpression(Expression)
*/
public S statusCodeExpression(String statusCodeExpression) {
this.target.setStatusCodeExpressionString(statusCodeExpression);
Expand All @@ -263,7 +265,7 @@ public S statusCodeExpression(String statusCodeExpression) {
* the default '200 OK' or '500 Internal Server Error' for a timeout.
* @param statusCodeExpression The status code Expression.
* @return the current Spec.
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setStatusCodeExpression(Expression)
* @see BaseHttpInboundEndpoint#setStatusCodeExpression(Expression)
*/
public S statusCodeExpression(Expression statusCodeExpression) {
this.target.setStatusCodeExpression(statusCodeExpression);
Expand All @@ -275,12 +277,23 @@ public S statusCodeExpression(Expression statusCodeExpression) {
* the default '200 OK' or '500 Internal Server Error' for a timeout.
* @param statusCodeFunction The status code {@link Function}.
* @return the current Spec.
* @see org.springframework.integration.http.inbound.HttpRequestHandlingEndpointSupport#setStatusCodeExpression(Expression)
* @see BaseHttpInboundEndpoint#setStatusCodeExpression(Expression)
*/
public S statusCodeFunction(Function<RequestEntity<?>, ?> statusCodeFunction) {
return statusCodeExpression(new FunctionExpression<>(statusCodeFunction));
}

/**
* Specify a {@link Validator} to validate a converted payload from request.
* @param validator the {@link Validator} to use.
* @return the spec
* @since 5.2
*/
public S validator(Validator validator) {
this.target.setValidator(validator);
return _this();
}

@Override
public Map<Object, String> getComponentsToRegister() {
HeaderMapper<HttpHeaders> headerMapperToRegister =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@
import org.springframework.integration.expression.ExpressionUtils;
import org.springframework.integration.gateway.MessagingGatewaySupport;
import org.springframework.integration.http.support.DefaultHttpHeaderMapper;
import org.springframework.integration.http.support.IntegrationWebExchangeBindException;
import org.springframework.integration.mapping.HeaderMapper;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.validation.BeanPropertyBindingResult;
import org.springframework.validation.ValidationUtils;
import org.springframework.validation.Validator;

/**
* The {@link MessagingGatewaySupport} extension for HTTP Inbound endpoints
Expand Down Expand Up @@ -65,6 +69,8 @@ public class BaseHttpInboundEndpoint extends MessagingGatewaySupport implements

private final boolean expectReply;

private Validator validator;

private ResolvableType requestPayloadType = null;

private HeaderMapper<HttpHeaders> headerMapper = DefaultHttpHeaderMapper.inboundMapper();
Expand Down Expand Up @@ -253,6 +259,19 @@ protected Expression getStatusCodeExpression() {
return this.statusCodeExpression;
}

/**
* Specify a {@link Validator} to validate a converted payload from request.
* @param validator the {@link Validator} to use.
* @since 5.2
*/
public void setValidator(Validator validator) {
this.validator = validator;
}

protected Validator getValidator() {
return this.validator;
}

@Override
protected void onInit() {
super.onInit();
Expand Down Expand Up @@ -334,4 +353,12 @@ protected boolean isReadable(HttpMethod httpMethod) {
return !(CollectionUtils.containsInstance(NON_READABLE_BODY_HTTP_METHODS, httpMethod));
}

protected void validate(Object value) {
BeanPropertyBindingResult errors = new BeanPropertyBindingResult(value, "requestPayload");
ValidationUtils.invokeValidator(this.validator, value, errors);
if (errors.hasErrors()) {
throw new IntegrationWebExchangeBindException(getComponentName(), value, errors);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ private Message<?> prepareRequestMessage(HttpServletRequest servletRequest, Requ

AbstractIntegrationMessageBuilder<?> messageBuilder;

if (getValidator() != null) {
validate(payload);
}

if (payload instanceof Message<?>) {
messageBuilder =
getMessageBuilderFactory()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,18 @@
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="validator" type="xsd:string">
<xsd:annotation>
<xsd:appinfo>
<tool:annotation kind="ref">
<tool:expected-type type="org.springframework.validation.Validator" />
</tool:annotation>
</xsd:appinfo>
<xsd:documentation>
A 'Validator' bean reference to validate a payload converted from the HTTP request.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
</xsd:attributeGroup>

<xsd:element name="outbound-channel-adapter">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<beans:beans
xmlns="http://www.springframework.org/schema/integration/http"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:beans="http://www.springframework.org/schema/beans"
xmlns:si="http://www.springframework.org/schema/integration"
xmlns:util="http://www.springframework.org/schema/util"
xsi:schemaLocation="http://www.springframework.org/schema/beans
xmlns="http://www.springframework.org/schema/integration/http"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:beans="http://www.springframework.org/schema/beans"
xmlns:si="http://www.springframework.org/schema/integration"
xmlns:util="http://www.springframework.org/schema/util"
xsi:schemaLocation="http://www.springframework.org/schema/beans
https://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/integration
https://www.springframework.org/schema/integration/spring-integration.xsd
Expand All @@ -20,20 +20,25 @@
<si:queue capacity="1"/>
</si:channel>

<beans:bean id="validator" class="org.mockito.Mockito" factory-method="mock">
<beans:constructor-arg value="org.springframework.validation.Validator"/>
</beans:bean>

<inbound-channel-adapter id="defaultAdapter" channel="requests" error-channel="errorChannel"
auto-startup="false"
phase="1001"
status-code-expression="'101'"/>
auto-startup="false"
phase="1001"
status-code-expression="'101'"
validator="validator"/>

<inbound-channel-adapter id="postOnlyAdapter" path="/postOnly" channel="requests" supported-methods="POST"/>

<inbound-channel-adapter id="adapterWithCustomConverterWithDefaults" message-converters="customConverters"
channel="requests" supported-methods="DELETE" merge-with-default-converters="true"/>
channel="requests" supported-methods="DELETE" merge-with-default-converters="true"/>

<inbound-channel-adapter id="adapterWithCustomConverterNoDefaults" message-converters="customConverters"
channel="requests" supported-methods="HEAD" />
channel="requests" supported-methods="HEAD"/>

<inbound-channel-adapter id="adapterNoCustomConverterNoDefaults" channel="requests" supported-methods="POST" />
<inbound-channel-adapter id="adapterNoCustomConverterNoDefaults" channel="requests" supported-methods="POST"/>

<util:list id="customConverters">
<beans:bean class="org.springframework.integration.http.converter.SerializingHttpMessageConverter"/>
Expand All @@ -42,7 +47,7 @@
<inbound-channel-adapter id="putOrDeleteAdapter" channel="requests" supported-methods="PUT, delete"/>

<inbound-channel-adapter id="inboundController" channel="requests" view-name="foo" error-code="oops"
status-code-expression="T(org.springframework.http.HttpStatus).ACCEPTED">
status-code-expression="T(org.springframework.http.HttpStatus).ACCEPTED">
<request-mapping headers="BAR"/>
</inbound-channel-adapter>

Expand Down
Loading