From 5cbe771a3cfbe3324602775a2b8c65a8a112e136 Mon Sep 17 00:00:00 2001 From: Marcus Da Coregio Date: Mon, 8 Aug 2022 15:34:17 -0300 Subject: [PATCH] Consistently handle RequestRejectedException if it is wrapped Closes gh-11645 --- .../security/web/FilterChainProxy.java | 14 ++++++++++++-- .../security/web/FilterChainProxyTests.java | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 2e76062fc3e..8a22469be2e 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -41,6 +41,7 @@ import org.springframework.security.web.firewall.RequestRejectedException; import org.springframework.security.web.firewall.RequestRejectedHandler; import org.springframework.security.web.firewall.StrictHttpFirewall; +import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -158,6 +159,8 @@ public class FilterChainProxy extends GenericFilterBean { private RequestRejectedHandler requestRejectedHandler = new DefaultRequestRejectedHandler(); + private ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer(); + public FilterChainProxy() { } @@ -186,8 +189,15 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha request.setAttribute(FILTER_APPLIED, Boolean.TRUE); doFilterInternal(request, response, chain); } - catch (RequestRejectedException ex) { - this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex); + catch (Exception ex) { + Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex); + Throwable requestRejectedException = this.throwableAnalyzer + .getFirstThrowableOfType(RequestRejectedException.class, causeChain); + if (!(requestRejectedException instanceof RequestRejectedException)) { + throw ex; + } + this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, + (RequestRejectedException) requestRejectedException); } finally { this.securityContextHolderStrategy.clearContext(); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index e08b84451d7..58169e6ee90 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -50,6 +50,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -262,4 +263,18 @@ public void requestRejectedHandlerIsCalledIfFirewallThrowsRequestRejectedExcepti verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); } + @Test + public void requestRejectedHandlerIsCalledIfFirewallThrowsWrappedRequestRejectedException() throws Exception { + HttpFirewall fw = mock(HttpFirewall.class); + RequestRejectedHandler rjh = mock(RequestRejectedHandler.class); + this.fcp.setFirewall(fw); + this.fcp.setRequestRejectedHandler(rjh); + RequestRejectedException requestRejectedException = new RequestRejectedException("Contains illegal chars"); + ServletException servletException = new ServletException(requestRejectedException); + given(fw.getFirewalledRequest(this.request)).willReturn(mock(FirewalledRequest.class)); + willThrow(servletException).given(this.chain).doFilter(any(), any()); + this.fcp.doFilter(this.request, this.response, this.chain); + verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); + } + }