Skip to content

Fix Issue 4001: CSRF tokens are vulnerable to a BREACH attack #4042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.web.csrf;

import java.io.IOException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashSet;

Expand All @@ -28,6 +29,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.security.crypto.codec.Base64;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.util.UrlUtils;
Expand Down Expand Up @@ -63,6 +65,8 @@ public final class CsrfFilter extends OncePerRequestFilter {
*/
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();

private static final SecureRandom secureRandom = new SecureRandom();

private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRepository tokenRepository;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
Expand Down Expand Up @@ -93,8 +97,9 @@ protected void doFilterInternal(HttpServletRequest request,
csrfToken = this.tokenRepository.generateToken(request);
this.tokenRepository.saveToken(csrfToken, request, response);
}
request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken);
XorEncodedToken xorEncodedToken = new XorEncodedToken(csrfToken);
request.setAttribute(CsrfToken.class.getName(), xorEncodedToken);
request.setAttribute(xorEncodedToken.getParameterName(), xorEncodedToken);

if (!this.requireCsrfProtectionMatcher.matches(request)) {
filterChain.doFilter(request, response);
Expand All @@ -105,7 +110,7 @@ protected void doFilterInternal(HttpServletRequest request,
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
if (!csrfToken.getToken().equals(actualToken)) {
if (!csrfToken.getToken().equals(xorDecodeToken(actualToken))) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Invalid CSRF token found for "
+ UrlUtils.buildFullRequestUrl(request));
Expand Down Expand Up @@ -175,4 +180,63 @@ public boolean matches(HttpServletRequest request) {
return !this.allowedMethods.contains(request.getMethod());
}
}

static String xorEncodeToken(String token) {
// XOR the token with random values to protect against
// a BREACH attack
int tokenLength = token.length();
byte[] encodedToken = new byte[tokenLength*2];
byte[] salt = new byte[token.length()];
secureRandom.nextBytes(salt);

for (int i=0; i < tokenLength; i++) {
encodedToken[i] = salt[i];
encodedToken[i+tokenLength] = (byte)(token.charAt(i) ^ salt[i]);
}

return new String(Base64.encode(encodedToken));
}

static String xorDecodeToken(String encodedToken) {
if (encodedToken == null)
return null;

byte[] tokenBytes;
try {
tokenBytes = Base64.decode(encodedToken.getBytes());
} catch (IllegalArgumentException e) {
// If the Base64 decode failed then return null
return null;
}

StringBuilder builder = new StringBuilder();
int tokenLength = tokenBytes.length /2;
for (int i=0; i < tokenLength; i++)
builder.append((char)(tokenBytes[i] ^ tokenBytes[i+tokenLength]));
return builder.toString();
}

static class XorEncodedToken implements CsrfToken {

protected final CsrfToken delegate;

private XorEncodedToken(CsrfToken delegate) {
this.delegate = delegate;
}

@Override
public String getHeaderName() {
return delegate.getHeaderName();
}

@Override
public String getParameterName() {
return delegate.getParameterName();
}

@Override
public String getToken() {
return xorEncodeToken(delegate.getToken());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ public void doFilterAccessDeniedNoTokenPresent()

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
Expand All @@ -140,13 +140,13 @@ public void doFilterAccessDeniedIncorrectTokenPresent()
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(),
this.token.getToken() + " INVALID");
CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID"));

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
Expand All @@ -160,13 +160,13 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeader()
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.addHeader(this.token.getHeaderName(),
this.token.getToken() + " INVALID");
CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID"));

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
Expand All @@ -179,15 +179,16 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParamete
throws ServletException, IOException {
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.request.setParameter(this.token.getParameterName(),
CsrfFilter.xorEncodeToken(this.token.getToken()));
this.request.addHeader(this.token.getHeaderName(),
this.token.getToken() + " INVALID");
CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID"));

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
Expand All @@ -203,9 +204,9 @@ public void doFilterNotCsrfRequestExistingToken()

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.filterChain).doFilter(this.request, this.response);
Expand Down Expand Up @@ -234,13 +235,14 @@ public void doFilterIsCsrfRequestExistingTokenHeader()
throws ServletException, IOException {
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.request.addHeader(this.token.getHeaderName(),
CsrfFilter.xorEncodeToken(this.token.getToken()));

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.filterChain).doFilter(this.request, this.response);
Expand All @@ -253,14 +255,15 @@ public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(),
this.token.getToken() + " INVALID");
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
CsrfFilter.xorEncodeToken(this.token.getToken()).replaceAll("^......","INVALID"));
this.request.addHeader(this.token.getHeaderName(),
CsrfFilter.xorEncodeToken(this.token.getToken()));

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.filterChain).doFilter(this.request, this.response);
Expand All @@ -272,13 +275,14 @@ public void doFilterIsCsrfRequestExistingToken()
throws ServletException, IOException {
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.request.setParameter(this.token.getParameterName(),
CsrfFilter.xorEncodeToken(this.token.getToken()));

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

verify(this.filterChain).doFilter(this.request, this.response);
Expand All @@ -292,7 +296,8 @@ public void doFilterIsCsrfRequestGenerateToken()
throws ServletException, IOException {
when(this.requestMatcher.matches(this.request)).thenReturn(true);
when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.request.setParameter(this.token.getParameterName(),
CsrfFilter.xorEncodeToken(this.token.getToken()));

this.filter.doFilter(this.request, this.response, this.filterChain);

Expand Down Expand Up @@ -381,9 +386,9 @@ public void doFilterDefaultAccessDenied() throws ServletException, IOException {

this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.request.getAttribute(this.token.getParameterName()))
assertToken(this.request.getAttribute(this.token.getParameterName()))
.isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName()))
assertToken(this.request.getAttribute(CsrfToken.class.getName()))
.isEqualTo(this.token);

assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
Expand All @@ -400,6 +405,18 @@ public void setAccessDeniedHandlerNull() {
this.filter.setAccessDeniedHandler(null);
}

@Test
public void encodeDecodeToken() {
String value = "sample";
String encodedValue1 = CsrfFilter.xorEncodeToken(value);
String encodedValue2 = CsrfFilter.xorEncodeToken(value);

assertThat(encodedValue1).isNotEqualTo(encodedValue2);

assertThat(value).isEqualTo(CsrfFilter.xorDecodeToken(encodedValue1));
assertThat(value).isEqualTo(CsrfFilter.xorDecodeToken(encodedValue2));
}

private static final CsrfTokenAssert assertToken(Object token) {
return new CsrfTokenAssert((CsrfToken) token);
}
Expand All @@ -420,7 +437,16 @@ public CsrfTokenAssert isEqualTo(CsrfToken expected) {
assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName());
assertThat(this.actual.getParameterName())
.isEqualTo(expected.getParameterName());
assertThat(this.actual.getToken()).isEqualTo(expected.getToken());

String expectedValue = expected.getToken();
if (expected instanceof CsrfFilter.XorEncodedToken)
expectedValue = CsrfFilter.xorDecodeToken(expectedValue);

String actualValue = this.actual.getToken();
if (this.actual instanceof CsrfFilter.XorEncodedToken)
actualValue = CsrfFilter.xorDecodeToken(actualValue);

assertThat(actualValue).isEqualTo(expectedValue);
return this;
}
}
Expand Down