|
1 | 1 | /*
|
2 |
| - * Copyright 2002-2019 the original author or authors. |
| 2 | + * Copyright 2002-2020 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
16 | 16 |
|
17 | 17 | package org.springframework.security.config.annotation.web.configurers.saml2;
|
18 | 18 |
|
| 19 | +import java.io.ByteArrayOutputStream; |
19 | 20 | import java.io.IOException;
|
| 21 | +import java.net.URLDecoder; |
20 | 22 | import java.time.Duration;
|
21 | 23 | import java.util.Arrays;
|
22 | 24 | import java.util.Base64;
|
23 | 25 | import java.util.Collection;
|
24 | 26 | import java.util.Collections;
|
| 27 | +import java.util.zip.Inflater; |
| 28 | +import java.util.zip.InflaterOutputStream; |
25 | 29 | import javax.servlet.ServletException;
|
26 | 30 | import javax.servlet.http.HttpServletRequest;
|
27 | 31 |
|
|
54 | 58 | import org.springframework.security.core.GrantedAuthority;
|
55 | 59 | import org.springframework.security.core.authority.SimpleGrantedAuthority;
|
56 | 60 | import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
|
| 61 | +import org.springframework.security.saml2.Saml2Exception; |
57 | 62 | import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
|
| 63 | +import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; |
58 | 64 | import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
|
59 | 65 | import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
| 66 | +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; |
60 | 67 | import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
61 | 68 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
62 | 69 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
69 | 76 | import org.springframework.security.web.context.SecurityContextRepository;
|
70 | 77 | import org.springframework.test.util.ReflectionTestUtils;
|
71 | 78 | import org.springframework.test.web.servlet.MockMvc;
|
| 79 | +import org.springframework.test.web.servlet.MvcResult; |
| 80 | +import org.springframework.web.util.UriComponents; |
| 81 | +import org.springframework.web.util.UriComponentsBuilder; |
72 | 82 |
|
| 83 | +import static java.nio.charset.StandardCharsets.UTF_8; |
73 | 84 | import static org.assertj.core.api.Assertions.assertThat;
|
74 | 85 | import static org.mockito.ArgumentMatchers.any;
|
75 | 86 | import static org.mockito.ArgumentMatchers.anyString;
|
@@ -157,6 +168,20 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t
|
157 | 168 | verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
|
158 | 169 | }
|
159 | 170 |
|
| 171 | + @Test |
| 172 | + public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception { |
| 173 | + this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire(); |
| 174 | + |
| 175 | + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")) |
| 176 | + .andReturn(); |
| 177 | + UriComponents components = UriComponentsBuilder |
| 178 | + .fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); |
| 179 | + String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); |
| 180 | + String decoded = URLDecoder.decode(samlRequest, "UTF-8"); |
| 181 | + String inflated = samlInflate(samlDecode(decoded)); |
| 182 | + assertThat(inflated).contains("ForceAuthn=\"true\""); |
| 183 | + } |
| 184 | + |
160 | 185 | private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
|
161 | 186 | // get the OpenSamlAuthenticationProvider
|
162 | 187 | Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
|
@@ -275,6 +300,29 @@ Saml2AuthenticationRequestContextResolver resolver() {
|
275 | 300 | }
|
276 | 301 | }
|
277 | 302 |
|
| 303 | + @EnableWebSecurity |
| 304 | + @Import(Saml2LoginConfigBeans.class) |
| 305 | + static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter { |
| 306 | + |
| 307 | + @Override |
| 308 | + protected void configure(HttpSecurity http) throws Exception { |
| 309 | + http |
| 310 | + .authorizeRequests(authz -> authz |
| 311 | + .anyRequest().authenticated() |
| 312 | + ) |
| 313 | + .saml2Login(saml2 -> {}); |
| 314 | + } |
| 315 | + |
| 316 | + @Bean |
| 317 | + Saml2AuthenticationRequestFactory authenticationRequestFactory() { |
| 318 | + OpenSamlAuthenticationRequestFactory authenticationRequestFactory = |
| 319 | + new OpenSamlAuthenticationRequestFactory(); |
| 320 | + authenticationRequestFactory.setAuthnRequestConsumerResolver( |
| 321 | + context -> authnRequest -> authnRequest.setForceAuthn(true)); |
| 322 | + return authenticationRequestFactory; |
| 323 | + } |
| 324 | + } |
| 325 | + |
278 | 326 | private static AuthenticationManager getAuthenticationManagerMock(String role) {
|
279 | 327 | return new AuthenticationManager() {
|
280 | 328 |
|
@@ -315,4 +363,23 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() {
|
315 | 363 | }
|
316 | 364 | }
|
317 | 365 |
|
| 366 | + private static org.apache.commons.codec.binary.Base64 BASE64 = |
| 367 | + new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'}); |
| 368 | + |
| 369 | + private static byte[] samlDecode(String s) { |
| 370 | + return BASE64.decode(s); |
| 371 | + } |
| 372 | + |
| 373 | + private static String samlInflate(byte[] b) { |
| 374 | + try { |
| 375 | + ByteArrayOutputStream out = new ByteArrayOutputStream(); |
| 376 | + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); |
| 377 | + iout.write(b); |
| 378 | + iout.finish(); |
| 379 | + return new String(out.toByteArray(), UTF_8); |
| 380 | + } |
| 381 | + catch (IOException e) { |
| 382 | + throw new Saml2Exception("Unable to inflate string", e); |
| 383 | + } |
| 384 | + } |
318 | 385 | }
|
0 commit comments