Skip to content

Commit 2c960d2

Browse files
committed
Add AuthnRequestConsumerResolver
Closes gh-8141
1 parent 2e5c87d commit 2c960d2

File tree

3 files changed

+124
-3
lines changed

3 files changed

+124
-3
lines changed

config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

+68-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,12 +16,16 @@
1616

1717
package org.springframework.security.config.annotation.web.configurers.saml2;
1818

19+
import java.io.ByteArrayOutputStream;
1920
import java.io.IOException;
21+
import java.net.URLDecoder;
2022
import java.time.Duration;
2123
import java.util.Arrays;
2224
import java.util.Base64;
2325
import java.util.Collection;
2426
import java.util.Collections;
27+
import java.util.zip.Inflater;
28+
import java.util.zip.InflaterOutputStream;
2529
import javax.servlet.ServletException;
2630
import javax.servlet.http.HttpServletRequest;
2731

@@ -54,9 +58,12 @@
5458
import org.springframework.security.core.GrantedAuthority;
5559
import org.springframework.security.core.authority.SimpleGrantedAuthority;
5660
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
61+
import org.springframework.security.saml2.Saml2Exception;
5762
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
63+
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
5864
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
5965
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
66+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
6067
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
6168
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
6269
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
@@ -69,7 +76,11 @@
6976
import org.springframework.security.web.context.SecurityContextRepository;
7077
import org.springframework.test.util.ReflectionTestUtils;
7178
import org.springframework.test.web.servlet.MockMvc;
79+
import org.springframework.test.web.servlet.MvcResult;
80+
import org.springframework.web.util.UriComponents;
81+
import org.springframework.web.util.UriComponentsBuilder;
7282

83+
import static java.nio.charset.StandardCharsets.UTF_8;
7384
import static org.assertj.core.api.Assertions.assertThat;
7485
import static org.mockito.ArgumentMatchers.any;
7586
import static org.mockito.ArgumentMatchers.anyString;
@@ -157,6 +168,20 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t
157168
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
158169
}
159170

171+
@Test
172+
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
173+
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
174+
175+
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
176+
.andReturn();
177+
UriComponents components = UriComponentsBuilder
178+
.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
179+
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
180+
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
181+
String inflated = samlInflate(samlDecode(decoded));
182+
assertThat(inflated).contains("ForceAuthn=\"true\"");
183+
}
184+
160185
private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
161186
// get the OpenSamlAuthenticationProvider
162187
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -275,6 +300,29 @@ Saml2AuthenticationRequestContextResolver resolver() {
275300
}
276301
}
277302

303+
@EnableWebSecurity
304+
@Import(Saml2LoginConfigBeans.class)
305+
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
306+
307+
@Override
308+
protected void configure(HttpSecurity http) throws Exception {
309+
http
310+
.authorizeRequests(authz -> authz
311+
.anyRequest().authenticated()
312+
)
313+
.saml2Login(saml2 -> {});
314+
}
315+
316+
@Bean
317+
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
318+
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
319+
new OpenSamlAuthenticationRequestFactory();
320+
authenticationRequestFactory.setAuthnRequestConsumerResolver(
321+
context -> authnRequest -> authnRequest.setForceAuthn(true));
322+
return authenticationRequestFactory;
323+
}
324+
}
325+
278326
private static AuthenticationManager getAuthenticationManagerMock(String role) {
279327
return new AuthenticationManager() {
280328

@@ -315,4 +363,23 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() {
315363
}
316364
}
317365

366+
private static org.apache.commons.codec.binary.Base64 BASE64 =
367+
new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'});
368+
369+
private static byte[] samlDecode(String s) {
370+
return BASE64.decode(s);
371+
}
372+
373+
private static String samlInflate(byte[] b) {
374+
try {
375+
ByteArrayOutputStream out = new ByteArrayOutputStream();
376+
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
377+
iout.write(b);
378+
iout.finish();
379+
return new String(out.toByteArray(), UTF_8);
380+
}
381+
catch (IOException e) {
382+
throw new Saml2Exception("Unable to inflate string", e);
383+
}
384+
}
318385
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.UUID;
24+
import java.util.function.Consumer;
25+
import java.util.function.Function;
2426

2527
import org.joda.time.DateTime;
2628
import org.opensaml.saml.common.xml.SAMLConstants;
@@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
4345
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
4446
private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
4547

48+
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
49+
= context -> authnRequest -> {};
50+
4651
@Override
4752
@Deprecated
4853
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
@@ -95,8 +100,10 @@ public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Sa
95100
}
96101

97102
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
98-
return createAuthnRequest(context.getIssuer(),
103+
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
99104
context.getDestination(), context.getAssertionConsumerServiceUrl());
105+
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
106+
return authnRequest;
100107
}
101108

102109
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
@@ -114,6 +121,18 @@ private AuthnRequest createAuthnRequest(String issuer, String destination, Strin
114121
return auth;
115122
}
116123

124+
/**
125+
* Set the {@link AuthnRequest} post-processor resolver
126+
*
127+
* @param authnRequestConsumerResolver
128+
* @since 5.4
129+
*/
130+
public void setAuthnRequestConsumerResolver(
131+
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
132+
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
133+
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
134+
}
135+
117136
/**
118137
* '
119138
* Use this {@link Clock} with {@link Instant#now()} for generating

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

+36-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.security.saml2.provider.service.authentication;
1818

19+
import java.util.function.Consumer;
20+
import java.util.function.Function;
21+
1922
import org.junit.Assert;
2023
import org.junit.Before;
2124
import org.junit.Rule;
@@ -29,9 +32,13 @@
2932

3033
import static java.nio.charset.StandardCharsets.UTF_8;
3134
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.assertj.core.api.Assertions.assertThatCode;
3236
import static org.hamcrest.CoreMatchers.containsString;
33-
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.verify;
39+
import static org.mockito.Mockito.when;
3440
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
41+
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
3542
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
3643
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
3744
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@@ -160,6 +167,34 @@ public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalAr
160167
factory.setProtocolBinding("my-invalid-binding");
161168
}
162169

170+
@Test
171+
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
172+
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
173+
mock(Function.class);
174+
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
175+
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
176+
177+
this.factory.createPostAuthenticationRequest(this.context);
178+
verify(authnRequestConsumerResolver).apply(this.context);
179+
}
180+
181+
@Test
182+
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
183+
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
184+
mock(Function.class);
185+
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
186+
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
187+
188+
this.factory.createRedirectAuthenticationRequest(this.context);
189+
verify(authnRequestConsumerResolver).apply(this.context);
190+
}
191+
192+
@Test
193+
public void setAuthnRequestConsumerResolverWhenNullThenException() {
194+
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
195+
.isInstanceOf(IllegalArgumentException.class);
196+
}
197+
163198
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
164199
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
165200
factory.createRedirectAuthenticationRequest(context) :

0 commit comments

Comments
 (0)