Skip to content

Commit 5a25568

Browse files
committed
Add Saml2AuthenticationRequestResolver
Closes gh-10355
1 parent 861368b commit 5a25568

File tree

15 files changed

+1404
-187
lines changed

15 files changed

+1404
-187
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

+74-54
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,8 +19,6 @@
1919
import java.util.LinkedHashMap;
2020
import java.util.Map;
2121

22-
import jakarta.servlet.Filter;
23-
2422
import org.opensaml.core.Version;
2523

2624
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
@@ -50,6 +48,7 @@
5048
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
5149
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
5250
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
51+
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
5352
import org.springframework.security.web.authentication.AuthenticationConverter;
5453
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
5554
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
@@ -115,9 +114,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
115114

116115
private String loginPage;
117116

118-
private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
117+
private String authenticationRequestUri = "/saml2/authenticate/{registrationId}";
118+
119+
private Saml2AuthenticationRequestResolver authenticationRequestResolver;
119120

120-
private AuthenticationRequestEndpointConfig authenticationRequestEndpoint = new AuthenticationRequestEndpointConfig();
121+
private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
121122

122123
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
123124

@@ -176,6 +177,20 @@ public Saml2LoginConfigurer<B> loginPage(String loginPage) {
176177
return this;
177178
}
178179

180+
/**
181+
* Use this {@link Saml2AuthenticationRequestResolver} for generating SAML 2.0
182+
* Authentication Requests.
183+
* @param authenticationRequestResolver
184+
* @return the {@link Saml2LoginConfigurer} for further configuration
185+
* @since 5.7
186+
*/
187+
public Saml2LoginConfigurer<B> authenticationRequestResolver(
188+
Saml2AuthenticationRequestResolver authenticationRequestResolver) {
189+
Assert.notNull(authenticationRequestResolver, "authenticationRequestResolver cannot be null");
190+
this.authenticationRequestResolver = authenticationRequestResolver;
191+
return this;
192+
}
193+
179194
/**
180195
* Specifies the URL to validate the credentials. If specified a custom URL, consider
181196
* specifying a custom {@link AuthenticationConverter} via
@@ -200,7 +215,7 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU
200215

201216
/**
202217
* {@inheritDoc}
203-
*
218+
* <p>
204219
* Initializes this filter chain for SAML 2 Login. The following actions are taken:
205220
* <ul>
206221
* <li>The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}</li>
@@ -226,8 +241,8 @@ public void init(B http) throws Exception {
226241
super.init(http);
227242
}
228243
else {
229-
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(
230-
this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository);
244+
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
245+
this.relyingPartyRegistrationRepository);
231246
boolean singleProvider = providerUrlMap.size() == 1;
232247
if (singleProvider) {
233248
// Setup auto-redirect to provider login page
@@ -247,14 +262,16 @@ public void init(B http) throws Exception {
247262

248263
/**
249264
* {@inheritDoc}
250-
*
265+
* <p>
251266
* During the {@code configure} phase, a
252267
* {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0
253268
* AuthNRequest redirects
254269
*/
255270
@Override
256271
public void configure(B http) throws Exception {
257-
http.addFilter(this.authenticationRequestEndpoint.build(http));
272+
Saml2WebSsoAuthenticationRequestFilter filter = getAuthenticationRequestFilter(http);
273+
filter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
274+
http.addFilter(postProcess(filter));
258275
super.configure(http);
259276
if (this.authenticationManager == null) {
260277
registerDefaultAuthenticationProvider(http);
@@ -264,6 +281,11 @@ public void configure(B http) throws Exception {
264281
}
265282
}
266283

284+
private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver(B http) {
285+
RelyingPartyRegistrationRepository registrations = relyingPartyRegistrationRepository(http);
286+
return new DefaultRelyingPartyRegistrationResolver(registrations);
287+
}
288+
267289
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(B http) {
268290
if (this.relyingPartyRegistrationRepository == null) {
269291
this.relyingPartyRegistrationRepository = getSharedOrBean(http, RelyingPartyRegistrationRepository.class);
@@ -276,6 +298,46 @@ private void setAuthenticationRequestRepository(B http,
276298
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
277299
}
278300

301+
private Saml2WebSsoAuthenticationRequestFilter getAuthenticationRequestFilter(B http) {
302+
Saml2AuthenticationRequestResolver authenticationRequestResolver = getAuthenticationRequestResolver(http);
303+
if (authenticationRequestResolver != null) {
304+
return new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestResolver);
305+
}
306+
return new Saml2WebSsoAuthenticationRequestFilter(getAuthenticationRequestContextResolver(http),
307+
getAuthenticationRequestFactory(http));
308+
}
309+
310+
private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B http) {
311+
if (this.authenticationRequestResolver != null) {
312+
return this.authenticationRequestResolver;
313+
}
314+
return getBeanOrNull(http, Saml2AuthenticationRequestResolver.class);
315+
}
316+
317+
private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory(B http) {
318+
Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
319+
if (resolver != null) {
320+
return resolver;
321+
}
322+
if (version().startsWith("4")) {
323+
return new OpenSaml4AuthenticationRequestFactory();
324+
}
325+
else {
326+
return new OpenSamlAuthenticationRequestFactory();
327+
}
328+
}
329+
330+
private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) {
331+
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
332+
Saml2AuthenticationRequestContextResolver.class);
333+
if (resolver != null) {
334+
return resolver;
335+
}
336+
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
337+
this.relyingPartyRegistrationRepository);
338+
return new DefaultSaml2AuthenticationRequestContextResolver(registrationResolver);
339+
}
340+
279341
private AuthenticationConverter getAuthenticationConverter(B http) {
280342
if (this.authenticationConverter != null) {
281343
return this.authenticationConverter;
@@ -325,8 +387,8 @@ private void initDefaultLoginFilter(B http) {
325387
return;
326388
}
327389
loginPageGeneratingFilter.setSaml2LoginEnabled(true);
328-
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(
329-
this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository));
390+
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
391+
this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository));
330392
loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
331393
loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
332394
}
@@ -380,46 +442,4 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
380442
}
381443
}
382444

383-
private final class AuthenticationRequestEndpointConfig {
384-
385-
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
386-
387-
private AuthenticationRequestEndpointConfig() {
388-
}
389-
390-
private Filter build(B http) {
391-
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
392-
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
393-
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
394-
http);
395-
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
396-
authenticationRequestResolver);
397-
filter.setAuthenticationRequestRepository(repository);
398-
return postProcess(filter);
399-
}
400-
401-
private Saml2AuthenticationRequestFactory getResolver(B http) {
402-
Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
403-
if (resolver == null) {
404-
if (version().startsWith("4")) {
405-
return new OpenSaml4AuthenticationRequestFactory();
406-
}
407-
return new OpenSamlAuthenticationRequestFactory();
408-
}
409-
return resolver;
410-
}
411-
412-
private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
413-
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
414-
Saml2AuthenticationRequestContextResolver.class);
415-
if (resolver == null) {
416-
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
417-
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository);
418-
return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver);
419-
}
420-
return resolver;
421-
}
422-
423-
}
424-
425445
}

config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

+138-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -80,9 +80,13 @@
8080
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
8181
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
8282
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
83+
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
84+
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
8385
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
8486
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
8587
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
88+
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
89+
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
8690
import org.springframework.security.web.FilterChainProxy;
8791
import org.springframework.security.web.SecurityFilterChain;
8892
import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -104,6 +108,7 @@
104108
import static org.mockito.BDDMockito.given;
105109
import static org.mockito.Mockito.mock;
106110
import static org.mockito.Mockito.verify;
111+
import static org.mockito.Mockito.verifyNoInteractions;
107112
import static org.springframework.security.config.Customizer.withDefaults;
108113
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
109114
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -211,6 +216,41 @@ public void authenticationRequestWhenAuthnRequestContextConverterThenUses() thro
211216
assertThat(inflated).contains("ForceAuthn=\"true\"");
212217
}
213218

219+
@Test
220+
public void authenticationRequestWhenAuthenticationRequestResolverBeanThenUses() throws Exception {
221+
this.spring.register(CustomAuthenticationRequestResolverBean.class).autowire();
222+
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
223+
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
224+
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
225+
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
226+
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
227+
assertThat(inflated).contains("ForceAuthn=\"true\"");
228+
}
229+
230+
@Test
231+
public void authenticationRequestWhenAuthenticationRequestResolverDslThenUses() throws Exception {
232+
this.spring.register(CustomAuthenticationRequestResolverDsl.class).autowire();
233+
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
234+
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
235+
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
236+
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
237+
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
238+
assertThat(inflated).contains("ForceAuthn=\"true\"");
239+
}
240+
241+
@Test
242+
public void authenticationRequestWhenAuthenticationRequestResolverAndFactoryThenResolverTakesPrecedence()
243+
throws Exception {
244+
this.spring.register(CustomAuthenticationRequestResolverPrecedence.class).autowire();
245+
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
246+
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
247+
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
248+
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
249+
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
250+
assertThat(inflated).contains("ForceAuthn=\"true\"");
251+
verifyNoInteractions(this.spring.getContext().getBean(Saml2AuthenticationRequestFactory.class));
252+
}
253+
214254
@Test
215255
public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception {
216256
this.spring.register(CustomAuthenticationConverter.class).autowire();
@@ -506,6 +546,103 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory() {
506546

507547
}
508548

549+
@EnableWebSecurity
550+
@Import(Saml2LoginConfigBeans.class)
551+
static class CustomAuthenticationRequestResolverBean {
552+
553+
@Bean
554+
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
555+
// @formatter:off
556+
http
557+
.authorizeRequests((authz) -> authz
558+
.anyRequest().authenticated()
559+
)
560+
.saml2Login(Customizer.withDefaults());
561+
// @formatter:on
562+
563+
return http.build();
564+
}
565+
566+
@Bean
567+
Saml2AuthenticationRequestResolver authenticationRequestResolver(
568+
RelyingPartyRegistrationRepository registrations) {
569+
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
570+
registrations);
571+
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
572+
registrationResolver);
573+
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
574+
return delegate;
575+
}
576+
577+
}
578+
579+
@EnableWebSecurity
580+
@Import(Saml2LoginConfigBeans.class)
581+
static class CustomAuthenticationRequestResolverDsl {
582+
583+
@Bean
584+
SecurityFilterChain filterChain(HttpSecurity http, RelyingPartyRegistrationRepository registrations)
585+
throws Exception {
586+
// @formatter:off
587+
http
588+
.authorizeRequests((authz) -> authz
589+
.anyRequest().authenticated()
590+
)
591+
.saml2Login((saml2) -> saml2
592+
.authenticationRequestResolver(authenticationRequestResolver(registrations))
593+
);
594+
// @formatter:on
595+
596+
return http.build();
597+
}
598+
599+
Saml2AuthenticationRequestResolver authenticationRequestResolver(
600+
RelyingPartyRegistrationRepository registrations) {
601+
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
602+
registrations);
603+
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
604+
registrationResolver);
605+
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
606+
return delegate;
607+
}
608+
609+
}
610+
611+
@EnableWebSecurity
612+
@Import(Saml2LoginConfigBeans.class)
613+
static class CustomAuthenticationRequestResolverPrecedence {
614+
615+
@Bean
616+
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
617+
// @formatter:off
618+
http
619+
.authorizeRequests((authz) -> authz
620+
.anyRequest().authenticated()
621+
)
622+
.saml2Login(Customizer.withDefaults());
623+
// @formatter:on
624+
625+
return http.build();
626+
}
627+
628+
@Bean
629+
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
630+
return mock(Saml2AuthenticationRequestFactory.class);
631+
}
632+
633+
@Bean
634+
Saml2AuthenticationRequestResolver authenticationRequestResolver(
635+
RelyingPartyRegistrationRepository registrations) {
636+
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
637+
registrations);
638+
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
639+
registrationResolver);
640+
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
641+
return delegate;
642+
}
643+
644+
}
645+
509646
@EnableWebSecurity
510647
@Import(Saml2LoginConfigBeans.class)
511648
static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter {

0 commit comments

Comments
 (0)