|
1 | 1 | /* |
2 | | - * Copyright 2002-2019 the original author or authors. |
| 2 | + * Copyright 2002-2020 the original author or authors. |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
|
16 | 16 |
|
17 | 17 | package org.springframework.security.config.annotation.web.configurers.saml2; |
18 | 18 |
|
| 19 | +import java.io.ByteArrayOutputStream; |
19 | 20 | import java.io.IOException; |
| 21 | +import java.net.URLDecoder; |
20 | 22 | import java.time.Duration; |
21 | 23 | import java.util.Arrays; |
22 | 24 | import java.util.Base64; |
23 | 25 | import java.util.Collection; |
24 | 26 | import java.util.Collections; |
| 27 | +import java.util.zip.Inflater; |
| 28 | +import java.util.zip.InflaterOutputStream; |
25 | 29 | import javax.servlet.ServletException; |
26 | 30 | import javax.servlet.http.HttpServletRequest; |
27 | 31 |
|
|
54 | 58 | import org.springframework.security.core.GrantedAuthority; |
55 | 59 | import org.springframework.security.core.authority.SimpleGrantedAuthority; |
56 | 60 | import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; |
| 61 | +import org.springframework.security.saml2.Saml2Exception; |
57 | 62 | import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; |
| 63 | +import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; |
58 | 64 | import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; |
59 | 65 | import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; |
| 66 | +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; |
60 | 67 | import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; |
61 | 68 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; |
62 | 69 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; |
|
69 | 76 | import org.springframework.security.web.context.SecurityContextRepository; |
70 | 77 | import org.springframework.test.util.ReflectionTestUtils; |
71 | 78 | import org.springframework.test.web.servlet.MockMvc; |
| 79 | +import org.springframework.test.web.servlet.MvcResult; |
| 80 | +import org.springframework.web.util.UriComponents; |
| 81 | +import org.springframework.web.util.UriComponentsBuilder; |
72 | 82 |
|
| 83 | +import static java.nio.charset.StandardCharsets.UTF_8; |
73 | 84 | import static org.assertj.core.api.Assertions.assertThat; |
74 | 85 | import static org.mockito.ArgumentMatchers.any; |
75 | 86 | import static org.mockito.ArgumentMatchers.anyString; |
@@ -157,6 +168,20 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t |
157 | 168 | verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)); |
158 | 169 | } |
159 | 170 |
|
| 171 | + @Test |
| 172 | + public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception { |
| 173 | + this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire(); |
| 174 | + |
| 175 | + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")) |
| 176 | + .andReturn(); |
| 177 | + UriComponents components = UriComponentsBuilder |
| 178 | + .fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); |
| 179 | + String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); |
| 180 | + String decoded = URLDecoder.decode(samlRequest, "UTF-8"); |
| 181 | + String inflated = samlInflate(samlDecode(decoded)); |
| 182 | + assertThat(inflated).contains("ForceAuthn=\"true\""); |
| 183 | + } |
| 184 | + |
160 | 185 | private void validateSaml2WebSsoAuthenticationFilterConfiguration() { |
161 | 186 | // get the OpenSamlAuthenticationProvider |
162 | 187 | Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); |
@@ -275,6 +300,29 @@ Saml2AuthenticationRequestContextResolver resolver() { |
275 | 300 | } |
276 | 301 | } |
277 | 302 |
|
| 303 | + @EnableWebSecurity |
| 304 | + @Import(Saml2LoginConfigBeans.class) |
| 305 | + static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter { |
| 306 | + |
| 307 | + @Override |
| 308 | + protected void configure(HttpSecurity http) throws Exception { |
| 309 | + http |
| 310 | + .authorizeRequests(authz -> authz |
| 311 | + .anyRequest().authenticated() |
| 312 | + ) |
| 313 | + .saml2Login(saml2 -> {}); |
| 314 | + } |
| 315 | + |
| 316 | + @Bean |
| 317 | + Saml2AuthenticationRequestFactory authenticationRequestFactory() { |
| 318 | + OpenSamlAuthenticationRequestFactory authenticationRequestFactory = |
| 319 | + new OpenSamlAuthenticationRequestFactory(); |
| 320 | + authenticationRequestFactory.setAuthnRequestConsumerResolver( |
| 321 | + context -> authnRequest -> authnRequest.setForceAuthn(true)); |
| 322 | + return authenticationRequestFactory; |
| 323 | + } |
| 324 | + } |
| 325 | + |
278 | 326 | private static AuthenticationManager getAuthenticationManagerMock(String role) { |
279 | 327 | return new AuthenticationManager() { |
280 | 328 |
|
@@ -315,4 +363,23 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() { |
315 | 363 | } |
316 | 364 | } |
317 | 365 |
|
| 366 | + private static org.apache.commons.codec.binary.Base64 BASE64 = |
| 367 | + new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'}); |
| 368 | + |
| 369 | + private static byte[] samlDecode(String s) { |
| 370 | + return BASE64.decode(s); |
| 371 | + } |
| 372 | + |
| 373 | + private static String samlInflate(byte[] b) { |
| 374 | + try { |
| 375 | + ByteArrayOutputStream out = new ByteArrayOutputStream(); |
| 376 | + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); |
| 377 | + iout.write(b); |
| 378 | + iout.finish(); |
| 379 | + return new String(out.toByteArray(), UTF_8); |
| 380 | + } |
| 381 | + catch (IOException e) { |
| 382 | + throw new Saml2Exception("Unable to inflate string", e); |
| 383 | + } |
| 384 | + } |
318 | 385 | } |
0 commit comments