Skip to content

Commit 6ecff6b

Browse files
committed
Saml2AuthenticationTokenConverter wrap SAMLResponse Base64 decode exception and inflate exception to Saml2AuthenticationException
Update copyright year to 2021 Closes spring-projectsgh-9310
1 parent 041e4aa commit 6ecff6b

File tree

4 files changed

+112
-35
lines changed

4 files changed

+112
-35
lines changed

config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

+42-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,26 +16,24 @@
1616

1717
package org.springframework.security.config.annotation.web.configurers.saml2;
1818

19-
import java.io.ByteArrayOutputStream;
2019
import java.io.IOException;
2120
import java.net.URLDecoder;
22-
import java.nio.charset.StandardCharsets;
2321
import java.time.Duration;
2422
import java.util.Arrays;
2523
import java.util.Base64;
2624
import java.util.Collection;
2725
import java.util.Collections;
28-
import java.util.zip.Inflater;
29-
import java.util.zip.InflaterOutputStream;
3026

3127
import javax.servlet.ServletException;
3228
import javax.servlet.http.HttpServletRequest;
29+
import javax.servlet.http.HttpServletResponse;
3330

3431
import org.junit.After;
3532
import org.junit.Assert;
3633
import org.junit.Before;
3734
import org.junit.Rule;
3835
import org.junit.Test;
36+
import org.mockito.ArgumentCaptor;
3937
import org.opensaml.saml.saml2.core.Assertion;
4038
import org.opensaml.saml.saml2.core.AuthnRequest;
4139

@@ -61,11 +59,13 @@
6159
import org.springframework.security.core.GrantedAuthority;
6260
import org.springframework.security.core.authority.SimpleGrantedAuthority;
6361
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
64-
import org.springframework.security.saml2.Saml2Exception;
62+
import org.springframework.security.saml2.core.Saml2ErrorCodes;
63+
import org.springframework.security.saml2.core.Saml2Utils;
6564
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
6665
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
6766
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
6867
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
68+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
6969
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
7070
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
7171
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
@@ -78,6 +78,7 @@
7878
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
7979
import org.springframework.security.web.FilterChainProxy;
8080
import org.springframework.security.web.authentication.AuthenticationConverter;
81+
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
8182
import org.springframework.security.web.context.HttpRequestResponseHolder;
8283
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
8384
import org.springframework.security.web.context.SecurityContextRepository;
@@ -188,7 +189,7 @@ public void authenticationRequestWhenAuthnRequestContextConverterThenUses() thro
188189
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
189190
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
190191
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
191-
String inflated = samlInflate(samlDecode(decoded));
192+
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
192193
assertThat(inflated).contains("ForceAuthn=\"true\"");
193194
}
194195

@@ -199,7 +200,7 @@ public void authenticateWhenCustomAuthenticationConverterThenUses() throws Excep
199200
.assertingPartyDetails((party) -> party.verificationX509Credentials(
200201
(c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())))
201202
.build();
202-
String response = new String(samlDecode(SIGNED_RESPONSE));
203+
String response = new String(Saml2Utils.samlDecode(SIGNED_RESPONSE));
203204
given(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class)))
204205
.willReturn(new Saml2AuthenticationToken(relyingPartyRegistration, response));
205206
// @formatter:off
@@ -210,6 +211,24 @@ public void authenticateWhenCustomAuthenticationConverterThenUses() throws Excep
210211
verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class));
211212
}
212213

214+
@Test
215+
public void authenticateWithInvalidDeflatedSAMLResponseThenFailureHandlerUses() throws Exception {
216+
this.spring.register(CustomAuthenticationFailureHandler.class).autowire();
217+
byte[] invalidDeflated = Saml2Utils.invalidSamlDeflate("response");
218+
String encoded = Saml2Utils.samlEncode(invalidDeflated);
219+
MockHttpServletRequestBuilder request = get("/login/saml2/sso/registration-id").queryParam("SAMLResponse",
220+
encoded);
221+
this.mvc.perform(request);
222+
ArgumentCaptor<Saml2AuthenticationException> captor = ArgumentCaptor
223+
.forClass(Saml2AuthenticationException.class);
224+
verify(CustomAuthenticationFailureHandler.authenticationFailureHandler).onAuthenticationFailure(
225+
any(HttpServletRequest.class), any(HttpServletResponse.class), captor.capture());
226+
Saml2AuthenticationException exception = captor.getValue();
227+
assertThat(exception.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE);
228+
assertThat(exception.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string");
229+
assertThat(exception.getCause()).isInstanceOf(IOException.class);
230+
}
231+
213232
private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
214233
// get the OpenSamlAuthenticationProvider
215234
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -244,26 +263,6 @@ private void performSaml2Login(String expected) throws IOException, ServletExcep
244263
.hasToString(expected);
245264
}
246265

247-
private static org.apache.commons.codec.binary.Base64 BASE64 = new org.apache.commons.codec.binary.Base64(0,
248-
new byte[] { '\n' });
249-
250-
private static byte[] samlDecode(String s) {
251-
return BASE64.decode(s);
252-
}
253-
254-
private static String samlInflate(byte[] b) {
255-
try {
256-
ByteArrayOutputStream out = new ByteArrayOutputStream();
257-
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
258-
iout.write(b);
259-
iout.finish();
260-
return new String(out.toByteArray(), StandardCharsets.UTF_8);
261-
}
262-
catch (IOException ex) {
263-
throw new Saml2Exception("Unable to inflate string", ex);
264-
}
265-
}
266-
267266
private static AuthenticationManager getAuthenticationManagerMock(String role) {
268267
return new AuthenticationManager() {
269268
@Override
@@ -314,6 +313,21 @@ public <O extends OpenSamlAuthenticationProvider> O postProcess(O provider) {
314313

315314
}
316315

316+
@EnableWebSecurity
317+
@Import(Saml2LoginConfigBeans.class)
318+
static class CustomAuthenticationFailureHandler extends WebSecurityConfigurerAdapter {
319+
320+
static final AuthenticationFailureHandler authenticationFailureHandler = mock(
321+
AuthenticationFailureHandler.class);
322+
323+
@Override
324+
protected void configure(HttpSecurity http) throws Exception {
325+
http.authorizeRequests((authz) -> authz.anyRequest().authenticated())
326+
.saml2Login((saml2) -> saml2.failureHandler(authenticationFailureHandler));
327+
}
328+
329+
}
330+
317331
@EnableWebSecurity
318332
@Import(Saml2LoginConfigBeans.class)
319333
static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,7 +28,9 @@
2828

2929
import org.springframework.core.convert.converter.Converter;
3030
import org.springframework.http.HttpMethod;
31-
import org.springframework.security.saml2.Saml2Exception;
31+
import org.springframework.security.saml2.core.Saml2Error;
32+
import org.springframework.security.saml2.core.Saml2ErrorCodes;
33+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
3234
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
3335
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3436
import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -82,8 +84,14 @@ private String inflateIfRequired(HttpServletRequest request, byte[] b) {
8284
return new String(b, StandardCharsets.UTF_8);
8385
}
8486

85-
private byte[] samlDecode(String s) {
86-
return BASE64.decode(s);
87+
private byte[] samlDecode(String base64SAML2Response) {
88+
try {
89+
return BASE64.decode(base64SAML2Response);
90+
}
91+
catch (IllegalArgumentException ex) {
92+
throw new Saml2AuthenticationException(
93+
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex);
94+
}
8795
}
8896

8997
private String samlInflate(byte[] b) {
@@ -95,7 +103,8 @@ private String samlInflate(byte[] b) {
95103
return new String(out.toByteArray(), StandardCharsets.UTF_8);
96104
}
97105
catch (IOException ex) {
98-
throw new Saml2Exception("Unable to inflate string", ex);
106+
throw new Saml2AuthenticationException(
107+
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex);
99108
}
100109
}
101110

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -30,6 +30,8 @@
3030

3131
public final class Saml2Utils {
3232

33+
public static final String INVALID_BASE64 = "cmVzcG9 \n\t\r uc2U9+/";
34+
3335
private static Base64 BASE64 = new Base64(0, new byte[] { '\n' });
3436

3537
private Saml2Utils() {
@@ -70,4 +72,18 @@ public static String samlInflate(byte[] b) {
7072
}
7173
}
7274

75+
public static byte[] invalidSamlDeflate(String s) {
76+
try {
77+
ByteArrayOutputStream out = new ByteArrayOutputStream();
78+
DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(out,
79+
new Deflater(Deflater.DEFLATED, false));
80+
deflaterOutputStream.write(s.getBytes(StandardCharsets.UTF_8));
81+
deflaterOutputStream.finish();
82+
return out.toByteArray();
83+
}
84+
catch (IOException ex) {
85+
throw new Saml2Exception("Unable to deflate string", ex);
86+
}
87+
}
88+
7389
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -29,14 +29,17 @@
2929
import org.springframework.core.convert.converter.Converter;
3030
import org.springframework.core.io.ClassPathResource;
3131
import org.springframework.mock.web.MockHttpServletRequest;
32+
import org.springframework.security.saml2.core.Saml2ErrorCodes;
3233
import org.springframework.security.saml2.core.Saml2Utils;
34+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
3335
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
3436
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3537
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
3638
import org.springframework.util.StreamUtils;
3739
import org.springframework.web.util.UriUtils;
3840

3941
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4043
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
4144
import static org.mockito.ArgumentMatchers.any;
4245
import static org.mockito.BDDMockito.given;
@@ -64,6 +67,22 @@ public void convertWhenSamlResponseThenToken() {
6467
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
6568
}
6669

70+
@Test
71+
public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
72+
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
73+
this.relyingPartyRegistrationResolver);
74+
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
75+
.willReturn(this.relyingPartyRegistration);
76+
MockHttpServletRequest request = new MockHttpServletRequest();
77+
request.setParameter("SAMLResponse", Saml2Utils.INVALID_BASE64);
78+
assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
79+
.withCauseInstanceOf(IllegalArgumentException.class)
80+
.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
81+
.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
82+
.satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription())
83+
.isEqualTo("Failed to decode SAMLResponse"));
84+
}
85+
6786
@Test
6887
public void convertWhenNoSamlResponseThenNull() {
6988
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
@@ -100,6 +119,25 @@ public void convertWhenGetRequestThenInflates() {
100119
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
101120
}
102121

122+
@Test
123+
public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() {
124+
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
125+
this.relyingPartyRegistrationResolver);
126+
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
127+
.willReturn(this.relyingPartyRegistration);
128+
MockHttpServletRequest request = new MockHttpServletRequest();
129+
request.setMethod("GET");
130+
byte[] invalidDeflated = Saml2Utils.invalidSamlDeflate("response");
131+
String encoded = Saml2Utils.samlEncode(invalidDeflated);
132+
request.setParameter("SAMLResponse", encoded);
133+
assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
134+
.withCauseInstanceOf(IOException.class)
135+
.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
136+
.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
137+
.satisfies(
138+
(ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string"));
139+
}
140+
103141
@Test
104142
public void constructorWhenResolverIsNullThenIllegalArgument() {
105143
assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));

0 commit comments

Comments
 (0)