Skip to content

Add Saml2AuthenticationRequestResolver #9277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import java.util.LinkedHashMap;
import java.util.Map;

import jakarta.servlet.Filter;

import org.opensaml.core.Version;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
Expand Down Expand Up @@ -50,6 +48,7 @@
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
Expand Down Expand Up @@ -115,9 +114,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>

private String loginPage;

private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
private String authenticationRequestUri = "/saml2/authenticate/{registrationId}";

private Saml2AuthenticationRequestResolver authenticationRequestResolver;

private AuthenticationRequestEndpointConfig authenticationRequestEndpoint = new AuthenticationRequestEndpointConfig();
private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;

private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;

Expand Down Expand Up @@ -176,6 +177,20 @@ public Saml2LoginConfigurer<B> loginPage(String loginPage) {
return this;
}

/**
* Use this {@link Saml2AuthenticationRequestResolver} for generating SAML 2.0
* Authentication Requests.
* @param authenticationRequestResolver
* @return the {@link Saml2LoginConfigurer} for further configuration
* @since 5.7
*/
public Saml2LoginConfigurer<B> authenticationRequestResolver(
Saml2AuthenticationRequestResolver authenticationRequestResolver) {
Assert.notNull(authenticationRequestResolver, "authenticationRequestResolver cannot be null");
this.authenticationRequestResolver = authenticationRequestResolver;
return this;
}

/**
* Specifies the URL to validate the credentials. If specified a custom URL, consider
* specifying a custom {@link AuthenticationConverter} via
Expand All @@ -200,7 +215,7 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU

/**
* {@inheritDoc}
*
* <p>
* Initializes this filter chain for SAML 2 Login. The following actions are taken:
* <ul>
* <li>The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}</li>
Expand All @@ -226,8 +241,8 @@ public void init(B http) throws Exception {
super.init(http);
}
else {
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(
this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository);
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
this.relyingPartyRegistrationRepository);
boolean singleProvider = providerUrlMap.size() == 1;
if (singleProvider) {
// Setup auto-redirect to provider login page
Expand All @@ -247,14 +262,16 @@ public void init(B http) throws Exception {

/**
* {@inheritDoc}
*
* <p>
* During the {@code configure} phase, a
* {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0
* AuthNRequest redirects
*/
@Override
public void configure(B http) throws Exception {
http.addFilter(this.authenticationRequestEndpoint.build(http));
Saml2WebSsoAuthenticationRequestFilter filter = getAuthenticationRequestFilter(http);
filter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
http.addFilter(postProcess(filter));
super.configure(http);
if (this.authenticationManager == null) {
registerDefaultAuthenticationProvider(http);
Expand All @@ -264,6 +281,11 @@ public void configure(B http) throws Exception {
}
}

private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver(B http) {
RelyingPartyRegistrationRepository registrations = relyingPartyRegistrationRepository(http);
return new DefaultRelyingPartyRegistrationResolver(registrations);
}

RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(B http) {
if (this.relyingPartyRegistrationRepository == null) {
this.relyingPartyRegistrationRepository = getSharedOrBean(http, RelyingPartyRegistrationRepository.class);
Expand All @@ -276,6 +298,46 @@ private void setAuthenticationRequestRepository(B http,
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
}

private Saml2WebSsoAuthenticationRequestFilter getAuthenticationRequestFilter(B http) {
Saml2AuthenticationRequestResolver authenticationRequestResolver = getAuthenticationRequestResolver(http);
if (authenticationRequestResolver != null) {
return new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestResolver);
}
return new Saml2WebSsoAuthenticationRequestFilter(getAuthenticationRequestContextResolver(http),
getAuthenticationRequestFactory(http));
}

private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B http) {
if (this.authenticationRequestResolver != null) {
return this.authenticationRequestResolver;
}
return getBeanOrNull(http, Saml2AuthenticationRequestResolver.class);
}

private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory(B http) {
Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
if (resolver != null) {
return resolver;
}
if (version().startsWith("4")) {
return new OpenSaml4AuthenticationRequestFactory();
}
else {
return new OpenSamlAuthenticationRequestFactory();
}
}

private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
Saml2AuthenticationRequestContextResolver.class);
if (resolver != null) {
return resolver;
}
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
this.relyingPartyRegistrationRepository);
return new DefaultSaml2AuthenticationRequestContextResolver(registrationResolver);
}

private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter != null) {
return this.authenticationConverter;
Expand Down Expand Up @@ -325,8 +387,8 @@ private void initDefaultLoginFilter(B http) {
return;
}
loginPageGeneratingFilter.setSaml2LoginEnabled(true);
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(
this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
}
Expand Down Expand Up @@ -380,46 +442,4 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
}
}

private final class AuthenticationRequestEndpointConfig {

private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";

private AuthenticationRequestEndpointConfig() {
}

private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
http);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
authenticationRequestResolver);
filter.setAuthenticationRequestRepository(repository);
return postProcess(filter);
}

private Saml2AuthenticationRequestFactory getResolver(B http) {
Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
if (resolver == null) {
if (version().startsWith("4")) {
return new OpenSaml4AuthenticationRequestFactory();
}
return new OpenSamlAuthenticationRequestFactory();
}
return resolver;
}

private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
Saml2AuthenticationRequestContextResolver.class);
if (resolver == null) {
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository);
return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver);
}
return resolver;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
Expand All @@ -104,6 +108,7 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
Expand Down Expand Up @@ -211,6 +216,41 @@ public void authenticationRequestWhenAuthnRequestContextConverterThenUses() thro
assertThat(inflated).contains("ForceAuthn=\"true\"");
}

@Test
public void authenticationRequestWhenAuthenticationRequestResolverBeanThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestResolverBean.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
}

@Test
public void authenticationRequestWhenAuthenticationRequestResolverDslThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestResolverDsl.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
}

@Test
public void authenticationRequestWhenAuthenticationRequestResolverAndFactoryThenResolverTakesPrecedence()
throws Exception {
this.spring.register(CustomAuthenticationRequestResolverPrecedence.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
verifyNoInteractions(this.spring.getContext().getBean(Saml2AuthenticationRequestFactory.class));
}

@Test
public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception {
this.spring.register(CustomAuthenticationConverter.class).autowire();
Expand Down Expand Up @@ -506,6 +546,103 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory() {

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestResolverBean {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authz) -> authz
.anyRequest().authenticated()
)
.saml2Login(Customizer.withDefaults());
// @formatter:on

return http.build();
}

@Bean
Saml2AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository registrations) {
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
registrations);
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
registrationResolver);
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
return delegate;
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestResolverDsl {

@Bean
SecurityFilterChain filterChain(HttpSecurity http, RelyingPartyRegistrationRepository registrations)
throws Exception {
// @formatter:off
http
.authorizeRequests((authz) -> authz
.anyRequest().authenticated()
)
.saml2Login((saml2) -> saml2
.authenticationRequestResolver(authenticationRequestResolver(registrations))
);
// @formatter:on

return http.build();
}

Saml2AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository registrations) {
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
registrations);
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
registrationResolver);
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
return delegate;
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestResolverPrecedence {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authz) -> authz
.anyRequest().authenticated()
)
.saml2Login(Customizer.withDefaults());
// @formatter:on

return http.build();
}

@Bean
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
return mock(Saml2AuthenticationRequestFactory.class);
}

@Bean
Saml2AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository registrations) {
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
registrations);
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
registrationResolver);
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
return delegate;
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter {
Expand Down
Loading