From 4bcb48ce9ce270676cf521bd8e08a1bd4fd90069 Mon Sep 17 00:00:00 2001 From: ray0052 Date: Wed, 24 Aug 2016 09:12:31 -0600 Subject: [PATCH] Added protection against BREACH attack to CSRF tokens --- .../security/web/csrf/CsrfFilter.java | 70 +++++++++++++++- .../security/web/csrf/CsrfFilterTests.java | 82 ++++++++++++------- 2 files changed, 121 insertions(+), 31 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 0a62e303bc3..627a4b8a49a 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -16,6 +16,7 @@ package org.springframework.security.web.csrf; import java.io.IOException; +import java.security.SecureRandom; import java.util.Arrays; import java.util.HashSet; @@ -28,6 +29,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.security.crypto.codec.Base64; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.security.web.util.UrlUtils; @@ -63,6 +65,8 @@ public final class CsrfFilter extends OncePerRequestFilter { */ public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher(); + private static final SecureRandom secureRandom = new SecureRandom(); + private final Log logger = LogFactory.getLog(getClass()); private final CsrfTokenRepository tokenRepository; private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; @@ -93,8 +97,9 @@ protected void doFilterInternal(HttpServletRequest request, csrfToken = this.tokenRepository.generateToken(request); this.tokenRepository.saveToken(csrfToken, request, response); } - request.setAttribute(CsrfToken.class.getName(), csrfToken); - request.setAttribute(csrfToken.getParameterName(), csrfToken); + XorEncodedToken xorEncodedToken = new XorEncodedToken(csrfToken); + request.setAttribute(CsrfToken.class.getName(), xorEncodedToken); + request.setAttribute(xorEncodedToken.getParameterName(), xorEncodedToken); if (!this.requireCsrfProtectionMatcher.matches(request)) { filterChain.doFilter(request, response); @@ -105,7 +110,7 @@ protected void doFilterInternal(HttpServletRequest request, if (actualToken == null) { actualToken = request.getParameter(csrfToken.getParameterName()); } - if (!csrfToken.getToken().equals(actualToken)) { + if (!csrfToken.getToken().equals(xorDecodeToken(actualToken))) { if (this.logger.isDebugEnabled()) { this.logger.debug("Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)); @@ -175,4 +180,63 @@ public boolean matches(HttpServletRequest request) { return !this.allowedMethods.contains(request.getMethod()); } } + + static String xorEncodeToken(String token) { + // XOR the token with random values to protect against + // a BREACH attack + int tokenLength = token.length(); + byte[] encodedToken = new byte[tokenLength*2]; + byte[] salt = new byte[token.length()]; + secureRandom.nextBytes(salt); + + for (int i=0; i < tokenLength; i++) { + encodedToken[i] = salt[i]; + encodedToken[i+tokenLength] = (byte)(token.charAt(i) ^ salt[i]); + } + + return new String(Base64.encode(encodedToken)); + } + + static String xorDecodeToken(String encodedToken) { + if (encodedToken == null) + return null; + + byte[] tokenBytes; + try { + tokenBytes = Base64.decode(encodedToken.getBytes()); + } catch (IllegalArgumentException e) { + // If the Base64 decode failed then return null + return null; + } + + StringBuilder builder = new StringBuilder(); + int tokenLength = tokenBytes.length /2; + for (int i=0; i < tokenLength; i++) + builder.append((char)(tokenBytes[i] ^ tokenBytes[i+tokenLength])); + return builder.toString(); + } + + static class XorEncodedToken implements CsrfToken { + + protected final CsrfToken delegate; + + private XorEncodedToken(CsrfToken delegate) { + this.delegate = delegate; + } + + @Override + public String getHeaderName() { + return delegate.getHeaderName(); + } + + @Override + public String getParameterName() { + return delegate.getParameterName(); + } + + @Override + public String getToken() { + return xorEncodeToken(delegate.getToken()); + } + } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 970276788cd..1c6d4cbc91a 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -124,9 +124,9 @@ public void doFilterAccessDeniedNoTokenPresent() this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), @@ -140,13 +140,13 @@ public void doFilterAccessDeniedIncorrectTokenPresent() when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), - this.token.getToken() + " INVALID"); + CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID")); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), @@ -160,13 +160,13 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeader() when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.addHeader(this.token.getHeaderName(), - this.token.getToken() + " INVALID"); + CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID")); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), @@ -179,15 +179,16 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParamete throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); - this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + this.request.setParameter(this.token.getParameterName(), + CsrfFilter.xorEncodeToken(this.token.getToken())); this.request.addHeader(this.token.getHeaderName(), - this.token.getToken() + " INVALID"); + CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID")); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), @@ -203,9 +204,9 @@ public void doFilterNotCsrfRequestExistingToken() this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); @@ -234,13 +235,14 @@ public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); - this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); + this.request.addHeader(this.token.getHeaderName(), + CsrfFilter.xorEncodeToken(this.token.getToken())); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); @@ -253,14 +255,15 @@ public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), - this.token.getToken() + " INVALID"); - this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); + CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID")); + this.request.addHeader(this.token.getHeaderName(), + CsrfFilter.xorEncodeToken(this.token.getToken())); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); @@ -272,13 +275,14 @@ public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); - this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + this.request.setParameter(this.token.getParameterName(), + CsrfFilter.xorEncodeToken(this.token.getToken())); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); @@ -292,7 +296,8 @@ public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); - this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + this.request.setParameter(this.token.getParameterName(), + CsrfFilter.xorEncodeToken(this.token.getToken())); this.filter.doFilter(this.request, this.response, this.filterChain); @@ -381,9 +386,9 @@ public void doFilterDefaultAccessDenied() throws ServletException, IOException { this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())) + assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())) + assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); @@ -400,6 +405,18 @@ public void setAccessDeniedHandlerNull() { this.filter.setAccessDeniedHandler(null); } + @Test + public void encodeDecodeToken() { + String value = "sample"; + String encodedValue1 = CsrfFilter.xorEncodeToken(value); + String encodedValue2 = CsrfFilter.xorEncodeToken(value); + + assertThat(encodedValue1).isNotEqualTo(encodedValue2); + + assertThat(value).isEqualTo(CsrfFilter.xorDecodeToken(encodedValue1)); + assertThat(value).isEqualTo(CsrfFilter.xorDecodeToken(encodedValue2)); + } + private static final CsrfTokenAssert assertToken(Object token) { return new CsrfTokenAssert((CsrfToken) token); } @@ -420,7 +437,16 @@ public CsrfTokenAssert isEqualTo(CsrfToken expected) { assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName()); assertThat(this.actual.getParameterName()) .isEqualTo(expected.getParameterName()); - assertThat(this.actual.getToken()).isEqualTo(expected.getToken()); + + String expectedValue = expected.getToken(); + if (expected instanceof CsrfFilter.XorEncodedToken) + expectedValue = CsrfFilter.xorDecodeToken(expectedValue); + + String actualValue = this.actual.getToken(); + if (this.actual instanceof CsrfFilter.XorEncodedToken) + actualValue = CsrfFilter.xorDecodeToken(actualValue); + + assertThat(actualValue).isEqualTo(expectedValue); return this; } }