diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index d6cf80068dc..7bb7b8b0ac3 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -19,8 +19,6 @@ import java.util.LinkedHashMap; import java.util.Map; -import jakarta.servlet.Filter; - import org.opensaml.core.Version; import org.springframework.beans.factory.NoSuchBeanDefinitionException; @@ -50,6 +48,7 @@ import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; +import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; @@ -115,9 +114,11 @@ public final class Saml2LoginConfigurer> private String loginPage; - private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; + private String authenticationRequestUri = "/saml2/authenticate/{registrationId}"; + + private Saml2AuthenticationRequestResolver authenticationRequestResolver; - private AuthenticationRequestEndpointConfig authenticationRequestEndpoint = new AuthenticationRequestEndpointConfig(); + private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; @@ -176,6 +177,20 @@ public Saml2LoginConfigurer loginPage(String loginPage) { return this; } + /** + * Use this {@link Saml2AuthenticationRequestResolver} for generating SAML 2.0 + * Authentication Requests. + * @param authenticationRequestResolver + * @return the {@link Saml2LoginConfigurer} for further configuration + * @since 5.7 + */ + public Saml2LoginConfigurer authenticationRequestResolver( + Saml2AuthenticationRequestResolver authenticationRequestResolver) { + Assert.notNull(authenticationRequestResolver, "authenticationRequestResolver cannot be null"); + this.authenticationRequestResolver = authenticationRequestResolver; + return this; + } + /** * Specifies the URL to validate the credentials. If specified a custom URL, consider * specifying a custom {@link AuthenticationConverter} via @@ -200,7 +215,7 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU /** * {@inheritDoc} - * + *

* Initializes this filter chain for SAML 2 Login. The following actions are taken: *

    *
  • The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}
  • @@ -226,8 +241,8 @@ public void init(B http) throws Exception { super.init(http); } else { - Map providerUrlMap = getIdentityProviderUrlMap( - this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository); + Map providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri, + this.relyingPartyRegistrationRepository); boolean singleProvider = providerUrlMap.size() == 1; if (singleProvider) { // Setup auto-redirect to provider login page @@ -247,14 +262,16 @@ public void init(B http) throws Exception { /** * {@inheritDoc} - * + *

    * During the {@code configure} phase, a * {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0 * AuthNRequest redirects */ @Override public void configure(B http) throws Exception { - http.addFilter(this.authenticationRequestEndpoint.build(http)); + Saml2WebSsoAuthenticationRequestFilter filter = getAuthenticationRequestFilter(http); + filter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http)); + http.addFilter(postProcess(filter)); super.configure(http); if (this.authenticationManager == null) { registerDefaultAuthenticationProvider(http); @@ -264,6 +281,11 @@ public void configure(B http) throws Exception { } } + private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver(B http) { + RelyingPartyRegistrationRepository registrations = relyingPartyRegistrationRepository(http); + return new DefaultRelyingPartyRegistrationResolver(registrations); + } + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(B http) { if (this.relyingPartyRegistrationRepository == null) { this.relyingPartyRegistrationRepository = getSharedOrBean(http, RelyingPartyRegistrationRepository.class); @@ -276,6 +298,46 @@ private void setAuthenticationRequestRepository(B http, saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http)); } + private Saml2WebSsoAuthenticationRequestFilter getAuthenticationRequestFilter(B http) { + Saml2AuthenticationRequestResolver authenticationRequestResolver = getAuthenticationRequestResolver(http); + if (authenticationRequestResolver != null) { + return new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestResolver); + } + return new Saml2WebSsoAuthenticationRequestFilter(getAuthenticationRequestContextResolver(http), + getAuthenticationRequestFactory(http)); + } + + private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B http) { + if (this.authenticationRequestResolver != null) { + return this.authenticationRequestResolver; + } + return getBeanOrNull(http, Saml2AuthenticationRequestResolver.class); + } + + private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory(B http) { + Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class); + if (resolver != null) { + return resolver; + } + if (version().startsWith("4")) { + return new OpenSaml4AuthenticationRequestFactory(); + } + else { + return new OpenSamlAuthenticationRequestFactory(); + } + } + + private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) { + Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, + Saml2AuthenticationRequestContextResolver.class); + if (resolver != null) { + return resolver; + } + RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver( + this.relyingPartyRegistrationRepository); + return new DefaultSaml2AuthenticationRequestContextResolver(registrationResolver); + } + private AuthenticationConverter getAuthenticationConverter(B http) { if (this.authenticationConverter != null) { return this.authenticationConverter; @@ -325,8 +387,8 @@ private void initDefaultLoginFilter(B http) { return; } loginPageGeneratingFilter.setSaml2LoginEnabled(true); - loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap( - this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository)); + loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName( + this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository)); loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage()); loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl()); } @@ -380,46 +442,4 @@ private void setSharedObject(B http, Class clazz, C object) { } } - private final class AuthenticationRequestEndpointConfig { - - private String filterProcessingUrl = "/saml2/authenticate/{registrationId}"; - - private AuthenticationRequestEndpointConfig() { - } - - private Filter build(B http) { - Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); - Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http); - Saml2AuthenticationRequestRepository repository = getAuthenticationRequestRepository( - http); - Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver, - authenticationRequestResolver); - filter.setAuthenticationRequestRepository(repository); - return postProcess(filter); - } - - private Saml2AuthenticationRequestFactory getResolver(B http) { - Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class); - if (resolver == null) { - if (version().startsWith("4")) { - return new OpenSaml4AuthenticationRequestFactory(); - } - return new OpenSamlAuthenticationRequestFactory(); - } - return resolver; - } - - private Saml2AuthenticationRequestContextResolver getContextResolver(B http) { - Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, - Saml2AuthenticationRequestContextResolver.class); - if (resolver == null) { - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver( - Saml2LoginConfigurer.this.relyingPartyRegistrationRepository); - return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver); - } - return resolver; - } - - } - } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 83e555430df..73e8fb27dca 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -80,9 +80,13 @@ import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; +import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.AuthenticationConverter; @@ -104,6 +108,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -211,6 +216,41 @@ public void authenticationRequestWhenAuthnRequestContextConverterThenUses() thro assertThat(inflated).contains("ForceAuthn=\"true\""); } + @Test + public void authenticationRequestWhenAuthenticationRequestResolverBeanThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestResolverBean.class).autowire(); + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn(); + UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); + String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); + String decoded = URLDecoder.decode(samlRequest, "UTF-8"); + String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded)); + assertThat(inflated).contains("ForceAuthn=\"true\""); + } + + @Test + public void authenticationRequestWhenAuthenticationRequestResolverDslThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestResolverDsl.class).autowire(); + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn(); + UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); + String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); + String decoded = URLDecoder.decode(samlRequest, "UTF-8"); + String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded)); + assertThat(inflated).contains("ForceAuthn=\"true\""); + } + + @Test + public void authenticationRequestWhenAuthenticationRequestResolverAndFactoryThenResolverTakesPrecedence() + throws Exception { + this.spring.register(CustomAuthenticationRequestResolverPrecedence.class).autowire(); + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn(); + UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); + String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); + String decoded = URLDecoder.decode(samlRequest, "UTF-8"); + String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded)); + assertThat(inflated).contains("ForceAuthn=\"true\""); + verifyNoInteractions(this.spring.getContext().getBean(Saml2AuthenticationRequestFactory.class)); + } + @Test public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception { this.spring.register(CustomAuthenticationConverter.class).autowire(); @@ -506,6 +546,103 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory() { } + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationRequestResolverBean { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authz) -> authz + .anyRequest().authenticated() + ) + .saml2Login(Customizer.withDefaults()); + // @formatter:on + + return http.build(); + } + + @Bean + Saml2AuthenticationRequestResolver authenticationRequestResolver( + RelyingPartyRegistrationRepository registrations) { + RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver( + registrations); + OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver( + registrationResolver); + delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true)); + return delegate; + } + + } + + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationRequestResolverDsl { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http, RelyingPartyRegistrationRepository registrations) + throws Exception { + // @formatter:off + http + .authorizeRequests((authz) -> authz + .anyRequest().authenticated() + ) + .saml2Login((saml2) -> saml2 + .authenticationRequestResolver(authenticationRequestResolver(registrations)) + ); + // @formatter:on + + return http.build(); + } + + Saml2AuthenticationRequestResolver authenticationRequestResolver( + RelyingPartyRegistrationRepository registrations) { + RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver( + registrations); + OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver( + registrationResolver); + delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true)); + return delegate; + } + + } + + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationRequestResolverPrecedence { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authz) -> authz + .anyRequest().authenticated() + ) + .saml2Login(Customizer.withDefaults()); + // @formatter:on + + return http.build(); + } + + @Bean + Saml2AuthenticationRequestFactory authenticationRequestFactory() { + return mock(Saml2AuthenticationRequestFactory.class); + } + + @Bean + Saml2AuthenticationRequestResolver authenticationRequestResolver( + RelyingPartyRegistrationRepository registrations) { + RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver( + registrations); + OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver( + registrationResolver); + delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true)); + return delegate; + } + + } + @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter { diff --git a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc index ba512b5a4a2..142186ac84c 100644 --- a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc +++ b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc @@ -176,92 +176,21 @@ var relyingPartyRegistration: RelyingPartyRegistration? = There are a number of reasons that you may want to adjust an `AuthnRequest`. For example, you may want `ForceAuthN` to be set to `true`, which Spring Security sets to `false` by default. -If you don't need information from the `HttpServletRequest` to make your decision, then the easiest way is to xref:servlet/saml2/login/overview.adoc#servlet-saml2login-opensaml-customization[register a custom `AuthnRequestMarshaller` with OpenSAML]. -This will give you access to post-process the `AuthnRequest` instance before it's serialized. - -But, if you do need something from the request, then you can use create a custom `Saml2AuthenticationRequestContext` implementation and then a `Converter` to build an `AuthnRequest` yourself, like so: - -==== -.Java -[source,java,role="primary"] ----- -@Component -public class AuthnRequestConverter implements - Converter { - - private final AuthnRequestBuilder authnRequestBuilder; - private final IssuerBuilder issuerBuilder; - - // ... constructor - - public AuthnRequest convert(Saml2AuthenticationRequestContext context) { - MySaml2AuthenticationRequestContext myContext = (MySaml2AuthenticationRequestContext) context; - Issuer issuer = issuerBuilder.buildObject(); - issuer.setValue(myContext.getIssuer()); - - AuthnRequest authnRequest = authnRequestBuilder.buildObject(); - authnRequest.setIssuer(issuer); - authnRequest.setDestination(myContext.getDestination()); - authnRequest.setAssertionConsumerServiceURL(myContext.getAssertionConsumerServiceUrl()); - - // ... additional settings - - authRequest.setForceAuthn(myContext.getForceAuthn()); - return authnRequest; - } -} ----- - -.Kotlin -[source,kotlin,role="secondary"] ----- -@Component -class AuthnRequestConverter : Converter { - private val authnRequestBuilder: AuthnRequestBuilder? = null - private val issuerBuilder: IssuerBuilder? = null - - // ... constructor - override fun convert(context: Saml2AuthenticationRequestContext): AuthnRequest { - val myContext: MySaml2AuthenticationRequestContext = context - val issuer: Issuer = issuerBuilder.buildObject() - issuer.value = myContext.getIssuer() - val authnRequest: AuthnRequest = authnRequestBuilder.buildObject() - authnRequest.issuer = issuer - authnRequest.destination = myContext.getDestination() - authnRequest.assertionConsumerServiceURL = myContext.getAssertionConsumerServiceUrl() - - // ... additional settings - authRequest.setForceAuthn(myContext.getForceAuthn()) - return authnRequest - } -} ----- -==== - -Then, you can construct your own `Saml2AuthenticationRequestContextResolver` and `Saml2AuthenticationRequestFactory` and publish them as ``@Bean``s: +You can customize elements of OpenSAML's `AuthnRequest` by publishing an `OpenSaml4AuthenticationRequestResolver` as a `@Bean`, like so: ==== .Java [source,java,role="primary"] ---- @Bean -Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver() { - Saml2AuthenticationRequestContextResolver resolver = - new DefaultSaml2AuthenticationRequestContextResolver(); - return request -> { - Saml2AuthenticationRequestContext context = resolver.resolve(request); - return new MySaml2AuthenticationRequestContext(context, request.getParameter("force") != null); - }; -} - -@Bean -Saml2AuthenticationRequestFactory authenticationRequestFactory( - AuthnRequestConverter authnRequestConverter) { - - OpenSaml4AuthenticationRequestFactory authenticationRequestFactory = - new OpenSaml4AuthenticationRequestFactory(); - authenticationRequestFactory.setAuthenticationRequestContextConverter(authnRequestConverter); - return authenticationRequestFactory; +Saml2AuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistrationRepository registrations) { + RelyingPartyRegistrationResolver registrationResolver = + new DefaultRelyingPartyRegistrationResolver(registrations); + OpenSaml4AuthenticationRequestResolver authenticationRequestResolver = + new OpenSaml4AuthenticationRequestResolver(registrationResolver); + authenticationRequestResolver.setAuthnRequestCustomizer((context) -> context + .getAuthnRequest().setForceAuthn(true)); + return authenticationRequestResolver; } ---- @@ -269,24 +198,14 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory( [source,kotlin,role="secondary"] ---- @Bean -open fun authenticationRequestContextResolver(): Saml2AuthenticationRequestContextResolver { - val resolver: Saml2AuthenticationRequestContextResolver = DefaultSaml2AuthenticationRequestContextResolver() - return Saml2AuthenticationRequestContextResolver { request: HttpServletRequest -> - val context = resolver.resolve(request) - MySaml2AuthenticationRequestContext( - context, - request.getParameter("force") != null - ) - } -} - -@Bean -open fun authenticationRequestFactory( - authnRequestConverter: AuthnRequestConverter? -): Saml2AuthenticationRequestFactory? { - val authenticationRequestFactory = OpenSaml4AuthenticationRequestFactory() - authenticationRequestFactory.setAuthenticationRequestContextConverter(authnRequestConverter) - return authenticationRequestFactory +fun authenticationRequestResolver(registrations : RelyingPartyRegistrationRepository) : Saml2AuthenticationRequestResolver { + val registrationResolver : RelyingPartyRegistrationResolver = + new DefaultRelyingPartyRegistrationResolver(registrations) + val authenticationRequestResolver : OpenSaml4AuthenticationRequestResolver = + new OpenSaml4AuthenticationRequestResolver(registrationResolver) + authenticationRequestResolver.setAuthnRequestCustomizer((context) -> context + .getAuthnRequest().setForceAuthn(true)) + return authenticationRequestResolver } ---- ==== diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java index 5fc84dd078a..d3e6dfd4c11 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java @@ -16,6 +16,7 @@ package org.springframework.security.saml2.provider.service.authentication; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; /** @@ -54,6 +55,17 @@ public static Builder withAuthenticationRequestContext(Saml2AuthenticationReques return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState()); } + /** + * Constructs a {@link Builder} from a {@link RelyingPartyRegistration} object. + * @param registration a relying party registration + * @return a modifiable builder object + * @since 5.7 + */ + public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { + String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation(); + return new Builder().authenticationRequestUri(location); + } + /** * Builder class for a {@link Saml2PostAuthenticationRequest} object. */ diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java index 80fec1d392b..eaafe98a177 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java @@ -16,6 +16,7 @@ package org.springframework.security.saml2.provider.service.authentication; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; /** @@ -77,6 +78,18 @@ public static Builder withAuthenticationRequestContext(Saml2AuthenticationReques return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState()); } + /** + * Constructs a {@link Saml2PostAuthenticationRequest.Builder} from a + * {@link RelyingPartyRegistration} object. + * @param registration a relying party registration + * @return a modifiable builder object + * @since 5.7 + */ + public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { + String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation(); + return new Builder().authenticationRequestUri(location); + } + /** * Builder class for a {@link Saml2RedirectAuthenticationRequest} object. */ diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 1d47d544e1e..4a5fea544b4 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher.MatchResult; @@ -78,11 +79,7 @@ */ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter { - private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver; - - private Saml2AuthenticationRequestFactory authenticationRequestFactory; - - private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); + private final Saml2AuthenticationRequestResolver authenticationRequestResolver; private Saml2AuthenticationRequestRepository authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository(); @@ -129,11 +126,20 @@ private static Saml2AuthenticationRequestFactory requestFactory() { public Saml2WebSsoAuthenticationRequestFilter( Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver, Saml2AuthenticationRequestFactory authenticationRequestFactory) { + this(new FactorySaml2AuthenticationRequestResolver(authenticationRequestContextResolver, + authenticationRequestFactory)); + } - Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); - Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); - this.authenticationRequestContextResolver = authenticationRequestContextResolver; - this.authenticationRequestFactory = authenticationRequestFactory; + /** + * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the strategy for + * resolving the {@code AuthnRequest} + * @param authenticationRequestResolver the strategy for resolving the + * {@code AuthnRequest} + * @since 5.7 + */ + public Saml2WebSsoAuthenticationRequestFilter(Saml2AuthenticationRequestResolver authenticationRequestResolver) { + Assert.notNull(authenticationRequestResolver, "authenticationRequestResolver cannot be null"); + this.authenticationRequestResolver = authenticationRequestResolver; } /** @@ -146,16 +152,23 @@ public Saml2WebSsoAuthenticationRequestFilter( @Deprecated public void setAuthenticationRequestFactory(Saml2AuthenticationRequestFactory authenticationRequestFactory) { Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); - this.authenticationRequestFactory = authenticationRequestFactory; + Assert.isInstanceOf(FactorySaml2AuthenticationRequestResolver.class, this.authenticationRequestResolver, + "You cannot supply both a Saml2AuthenticationRequestResolver and a Saml2AuthenticationRequestFactory"); + ((FactorySaml2AuthenticationRequestResolver) this.authenticationRequestResolver).authenticationRequestFactory = authenticationRequestFactory; } /** * Use the given {@link RequestMatcher} that activates this filter for a given request * @param redirectMatcher the {@link RequestMatcher} to use + * @deprecated Configure the request matcher in an implementation of + * {@link Saml2AuthenticationRequestResolver} instead */ + @Deprecated public void setRedirectMatcher(RequestMatcher redirectMatcher) { Assert.notNull(redirectMatcher, "redirectMatcher cannot be null"); - this.redirectMatcher = redirectMatcher; + Assert.isInstanceOf(FactorySaml2AuthenticationRequestResolver.class, this.authenticationRequestResolver, + "You cannot supply a Saml2AuthenticationRequestResolver and a redirect matcher"); + ((FactorySaml2AuthenticationRequestResolver) this.authenticationRequestResolver).redirectMatcher = redirectMatcher; } /** @@ -174,30 +187,21 @@ public void setAuthenticationRequestRepository( @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - MatchResult matcher = this.redirectMatcher.matcher(request); - if (!matcher.isMatch()) { + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestResolver.resolve(request); + if (authenticationRequest == null) { filterChain.doFilter(request, response); return; } - - Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request); - if (context == null) { - response.sendError(HttpServletResponse.SC_UNAUTHORIZED); - return; - } - RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration(); - if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { - sendRedirect(request, response, context); + if (authenticationRequest instanceof Saml2RedirectAuthenticationRequest) { + sendRedirect(request, response, (Saml2RedirectAuthenticationRequest) authenticationRequest); } else { - sendPost(request, response, context); + sendPost(request, response, (Saml2PostAuthenticationRequest) authenticationRequest); } } private void sendRedirect(HttpServletRequest request, HttpServletResponse response, - Saml2AuthenticationRequestContext context) throws IOException { - Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory - .createRedirectAuthenticationRequest(context); + Saml2RedirectAuthenticationRequest authenticationRequest) throws IOException { this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response); UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(authenticationRequest.getAuthenticationRequestUri()); @@ -218,9 +222,7 @@ private void addParameter(String name, String value, UriComponentsBuilder builde } private void sendPost(HttpServletRequest request, HttpServletResponse response, - Saml2AuthenticationRequestContext context) throws IOException { - Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory - .createPostAuthenticationRequest(context); + Saml2PostAuthenticationRequest authenticationRequest) throws IOException { this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response); String html = createSamlPostRequestFormData(authenticationRequest); response.setContentType(MediaType.TEXT_HTML_VALUE); @@ -269,4 +271,41 @@ private String createSamlPostRequestFormData(Saml2PostAuthenticationRequest auth return html.toString(); } + private static class FactorySaml2AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver { + + private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver; + + private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); + + private Saml2AuthenticationRequestFactory authenticationRequestFactory; + + FactorySaml2AuthenticationRequestResolver( + Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver, + Saml2AuthenticationRequestFactory authenticationRequestFactory) { + Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); + Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); + this.authenticationRequestContextResolver = authenticationRequestContextResolver; + this.authenticationRequestFactory = authenticationRequestFactory; + } + + @Override + public AbstractSaml2AuthenticationRequest resolve(HttpServletRequest request) { + MatchResult matcher = this.redirectMatcher.matcher(request); + if (!matcher.isMatch()) { + return null; + } + Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request); + if (context == null) { + return null; + } + Saml2MessageBinding binding = context.getRelyingPartyRegistration().getAssertingPartyDetails() + .getSingleSignOnServiceBinding(); + if (binding == Saml2MessageBinding.REDIRECT) { + return this.authenticationRequestFactory.createRedirectAuthenticationRequest(context); + } + return this.authenticationRequestFactory.createPostAuthenticationRequest(context); + } + + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java new file mode 100644 index 00000000000..5f4b3333b22 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.UUID; +import java.util.function.BiConsumer; + +import jakarta.servlet.http.HttpServletRequest; + +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.opensaml.core.config.ConfigurationService; +import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.saml.saml2.core.AuthnRequest; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; +import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller; +import org.opensaml.saml.saml2.core.impl.IssuerBuilder; +import org.opensaml.saml.saml2.core.impl.NameIDBuilder; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * For internal use only. Intended for consolidating common behavior related to minting a + * SAML 2.0 Authn Request. + */ +class OpenSamlAuthenticationRequestResolver { + + static { + OpenSamlInitializationService.initialize(); + } + + private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); + + private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; + + private final AuthnRequestBuilder authnRequestBuilder; + + private final AuthnRequestMarshaller marshaller; + + private final IssuerBuilder issuerBuilder; + + private final NameIDBuilder nameIdBuilder; + + /** + * Construct a {@link OpenSamlAuthenticationRequestResolver} using the provided + * parameters + * @param relyingPartyRegistrationResolver a strategy for resolving the + * {@link RelyingPartyRegistration} from the {@link HttpServletRequest} + */ + OpenSamlAuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory() + .getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + Assert.notNull(this.marshaller, "logoutRequestMarshaller must be configured in OpenSAML"); + this.authnRequestBuilder = (AuthnRequestBuilder) XMLObjectProviderRegistrySupport.getBuilderFactory() + .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME); + Assert.notNull(this.authnRequestBuilder, "authnRequestBuilder must be configured in OpenSAML"); + this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME); + Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML"); + this.nameIdBuilder = (NameIDBuilder) registry.getBuilderFactory().getBuilder(NameID.DEFAULT_ELEMENT_NAME); + Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML"); + } + + T resolve(HttpServletRequest request) { + return resolve(request, (registration, logoutRequest) -> { + }); + } + + T resolve(HttpServletRequest request, + BiConsumer authnRequestConsumer) { + RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); + if (!result.isMatch()) { + return null; + } + String registrationId = result.getVariables().get("registrationId"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId); + if (registration == null) { + return null; + } + AuthnRequest authnRequest = this.authnRequestBuilder.buildObject(); + authnRequest.setForceAuthn(Boolean.FALSE); + authnRequest.setIsPassive(Boolean.FALSE); + authnRequest.setProtocolBinding(registration.getAssertionConsumerServiceBinding().getUrn()); + Issuer iss = this.issuerBuilder.buildObject(); + iss.setValue(registration.getEntityId()); + authnRequest.setIssuer(iss); + authnRequest.setDestination(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + authnRequest.setAssertionConsumerServiceURL(registration.getAssertionConsumerServiceLocation()); + authnRequestConsumer.accept(registration, authnRequest); + if (authnRequest.getID() == null) { + authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); + } + String relayState = UUID.randomUUID().toString(); + Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding(); + if (binding == Saml2MessageBinding.POST) { + if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { + OpenSamlSigningUtils.sign(authnRequest, registration); + } + String xml = serialize(authnRequest); + String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)); + return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded) + .relayState(relayState).build(); + } + else { + String xml = serialize(authnRequest); + String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest + .withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState); + if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { + Map parameters = OpenSamlSigningUtils.sign(registration) + .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) + .param(Saml2ParameterNames.RELAY_STATE, relayState).parameters(); + builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG)) + .signature(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + return (T) builder.build(); + } + } + + private String serialize(AuthnRequest authnRequest) { + try { + Element element = this.marshaller.marshall(authnRequest); + return SerializeSupport.nodeToString(element); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java new file mode 100644 index 00000000000..e6a8f94108b --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java @@ -0,0 +1,173 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * Utility methods for signing SAML components with OpenSAML + * + * For internal use only. + * + * @author Josh Cummings + */ +final class OpenSamlSigningUtils { + + static String serialize(XMLObject object) { + try { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + Element element = marshaller.marshall(object); + return SerializeSupport.nodeToString(element); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + static O sign(O object, RelyingPartyRegistration relyingPartyRegistration) { + SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); + try { + SignatureSupport.signObject(object, parameters); + return object; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + static QueryParametersPartial sign(RelyingPartyRegistration registration) { + return new QueryParametersPartial(registration); + } + + private static SignatureSigningParameters resolveSigningParameters( + RelyingPartyRegistration relyingPartyRegistration) { + List credentials = resolveSigningCredentials(relyingPartyRegistration); + List algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + CriteriaSet criteria = new CriteriaSet(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(algorithms); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private static List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setEntityId(relyingPartyRegistration.getEntityId()); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + private OpenSamlSigningUtils() { + + } + + static class QueryParametersPartial { + + final RelyingPartyRegistration registration; + + final Map components = new LinkedHashMap<>(); + + QueryParametersPartial(RelyingPartyRegistration registration) { + this.registration = registration; + } + + QueryParametersPartial param(String key, String value) { + this.components.put(key, value); + return this; + } + + Map parameters() { + SignatureSigningParameters parameters = resolveSigningParameters(this.registration); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put("SigAlg", algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put("Signature", b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java new file mode 100644 index 00000000000..f43d5ddecb1 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java @@ -0,0 +1,207 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + +import jakarta.servlet.http.HttpServletRequest; + +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; + +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.web.util.UriUtils; + +/** + * Utility methods for verifying SAML component signatures with OpenSAML + * + * For internal use only. + * + * @author Josh Cummings + */ + +final class OpenSamlVerificationUtils { + + static VerifierPartial verifySignature(StatusResponseType object, RelyingPartyRegistration registration) { + return new VerifierPartial(object, registration); + } + + static VerifierPartial verifySignature(RequestAbstractType object, RelyingPartyRegistration registration) { + return new VerifierPartial(object, registration); + } + + private OpenSamlVerificationUtils() { + + } + + static class VerifierPartial { + + private final String id; + + private final CriteriaSet criteria; + + private final SignatureTrustEngine trustEngine; + + VerifierPartial(StatusResponseType object, RelyingPartyRegistration registration) { + this.id = object.getID(); + this.criteria = verificationCriteria(object.getIssuer()); + this.trustEngine = trustEngine(registration); + } + + VerifierPartial(RequestAbstractType object, RelyingPartyRegistration registration) { + this.id = object.getID(); + this.criteria = verificationCriteria(object.getIssuer()); + this.trustEngine = trustEngine(registration); + } + + Saml2ResponseValidatorResult redirect(HttpServletRequest request, String objectParameterName) { + RedirectSignature signature = new RedirectSignature(request, objectParameterName); + if (signature.getAlgorithm() == null) { + return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + this.id + "]")); + } + if (!signature.hasSignature()) { + return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + this.id + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = signature.getAlgorithm(); + try { + if (!this.trustEngine.validate(signature.getSignature(), signature.getContent(), algorithmUri, + this.criteria, null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + this.id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + this.id + "]: ")); + } + return Saml2ResponseValidatorResult.failure(errors); + } + + Saml2ResponseValidatorResult post(Signature signature) { + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + this.id + "]: ")); + } + + try { + if (!this.trustEngine.validate(signature, this.criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + this.id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + this.id + "]: ")); + } + + return Saml2ResponseValidatorResult.failure(errors); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + CriteriaSet criteria = new CriteriaSet(); + criteria.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue()))); + criteria.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS))); + criteria.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + return criteria; + } + + private SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) { + Set credentials = new HashSet<>(); + Collection keys = registration.getAssertingPartyDetails() + .getVerificationX509Credentials(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(registration.getAssertingPartyDetails().getEntityId()); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private static class RedirectSignature { + + private final HttpServletRequest request; + + private final String objectParameterName; + + RedirectSignature(HttpServletRequest request, String objectParameterName) { + this.request = request; + this.objectParameterName = objectParameterName; + } + + String getAlgorithm() { + return this.request.getParameter("SigAlg"); + } + + byte[] getContent() { + String query = String.format("%s=%s&SigAlg=%s", this.objectParameterName, + UriUtils.encode(this.request.getParameter(this.objectParameterName), + StandardCharsets.ISO_8859_1), + UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1)); + return query.getBytes(StandardCharsets.UTF_8); + } + + byte[] getSignature() { + return Saml2Utils.samlDecode(this.request.getParameter("Signature")); + } + + boolean hasSignature() { + return this.request.getParameter("Signature") != null; + } + + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2AuthenticationRequestResolver.java new file mode 100644 index 00000000000..80bb0ecef6d --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2AuthenticationRequestResolver.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; + +/** + * A strategy for resolving a SAML 2.0 Authentication Request from the + * {@link HttpServletRequest}. + * + * @author Josh Cummings + * @since 5.7 + */ +public interface Saml2AuthenticationRequestResolver { + + T resolve(HttpServletRequest request); + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java new file mode 100644 index 00000000000..daef78d49ae --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.zip.Deflater; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.Inflater; +import java.util.zip.InflaterOutputStream; + +import org.apache.commons.codec.binary.Base64; + +import org.springframework.security.saml2.Saml2Exception; + +/** + * Utility methods for working with serialized SAML messages. + * + * For internal use only. + * + * @author Josh Cummings + */ +final class Saml2Utils { + + private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }); + + private Saml2Utils() { + } + + static String samlEncode(byte[] b) { + return BASE64.encodeAsString(b); + } + + static byte[] samlDecode(String s) { + return BASE64.decode(s); + } + + static byte[] samlDeflate(String s) { + try { + ByteArrayOutputStream b = new ByteArrayOutputStream(); + DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true)); + deflater.write(s.getBytes(StandardCharsets.UTF_8)); + deflater.finish(); + return b.toByteArray(); + } + catch (IOException ex) { + throw new Saml2Exception("Unable to deflate string", ex); + } + } + + static String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); + iout.write(b); + iout.finish(); + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } + catch (IOException ex) { + throw new Saml2Exception("Unable to inflate string", ex); + } + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml3Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml3AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/opensaml3Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml3AuthenticationRequestResolver.java new file mode 100644 index 00000000000..a6e462e1cae --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml3Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml3AuthenticationRequestResolver.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.time.Clock; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; + +import org.joda.time.DateTime; +import org.opensaml.saml.saml2.core.AuthnRequest; +import org.opensaml.saml.saml2.core.LogoutRequest; + +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.util.Assert; + +/** + * A strategy for resolving a SAML 2.0 Authentication Request from the + * {@link HttpServletRequest} using OpenSAML. + * + * @author Josh Cummings + * @since 5.7 + * @deprecated OpenSAML 3 has reached end-of-life so this version is no longer recommended + */ +@Deprecated +public final class OpenSaml3AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver { + + private final OpenSamlAuthenticationRequestResolver authnRequestResolver; + + private Consumer contextConsumer = (parameters) -> { + }; + + private Clock clock = Clock.systemUTC(); + + /** + * Construct a {@link OpenSaml3AuthenticationRequestResolver} + */ + public OpenSaml3AuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this.authnRequestResolver = new OpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver); + } + + @Override + public T resolve(HttpServletRequest request) { + return this.authnRequestResolver.resolve(request, (registration, authnRequest) -> { + authnRequest.setIssueInstant(new DateTime(this.clock.millis())); + this.contextConsumer.accept(new AuthnRequestContext(request, registration, authnRequest)); + }); + } + + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link LogoutRequest} + * @param contextConsumer a consumer that accepts an {@link AuthnRequestContext} + */ + public void setAuthnRequestCustomizer(Consumer contextConsumer) { + Assert.notNull(contextConsumer, "contextConsumer cannot be null"); + this.contextConsumer = contextConsumer; + } + + /** + * Use this {@link Clock} for generating the issued {@link DateTime} + * @param clock the {@link Clock} to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock must not be null"); + this.clock = clock; + } + + public static final class AuthnRequestContext { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final AuthnRequest authnRequest; + + public AuthnRequestContext(HttpServletRequest request, RelyingPartyRegistration registration, + AuthnRequest authnRequest) { + this.request = request; + this.registration = registration; + this.authnRequest = authnRequest; + } + + public HttpServletRequest getRequest() { + return this.request; + } + + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + public AuthnRequest getAuthnRequest() { + return this.authnRequest; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java new file mode 100644 index 00000000000..0bb7687458b --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.time.Clock; +import java.time.Instant; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; + +import org.opensaml.saml.saml2.core.AuthnRequest; + +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.util.Assert; + +/** + * A strategy for resolving a SAML 2.0 Authentication Request from the + * {@link HttpServletRequest} using OpenSAML. + * + * @author Josh Cummings + * @since 5.7 + */ +public final class OpenSaml4AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver { + + private final OpenSamlAuthenticationRequestResolver authnRequestResolver; + + private Consumer contextConsumer = (parameters) -> { + }; + + private Clock clock = Clock.systemUTC(); + + /** + * Construct a {@link OpenSaml4AuthenticationRequestResolver} + */ + public OpenSaml4AuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this.authnRequestResolver = new OpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver); + } + + @Override + public T resolve(HttpServletRequest request) { + return this.authnRequestResolver.resolve(request, (registration, authnRequest) -> { + authnRequest.setIssueInstant(Instant.now(this.clock)); + this.contextConsumer.accept(new AuthnRequestContext(request, registration, authnRequest)); + }); + } + + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link AuthnRequest} + * @param contextConsumer a consumer that accepts an {@link AuthnRequestContext} + */ + public void setAuthnRequestCustomizer(Consumer contextConsumer) { + Assert.notNull(contextConsumer, "contextConsumer cannot be null"); + this.contextConsumer = contextConsumer; + } + + /** + * Use this {@link Clock} for generating the issued {@link Instant} + * @param clock the {@link Clock} to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock must not be null"); + this.clock = clock; + } + + public static final class AuthnRequestContext { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final AuthnRequest authnRequest; + + public AuthnRequestContext(HttpServletRequest request, RelyingPartyRegistration registration, + AuthnRequest authnRequest) { + this.request = request; + this.registration = registration; + this.authnRequest = authnRequest; + } + + public HttpServletRequest getRequest() { + return this.request; + } + + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + public AuthnRequest getAuthnRequest() { + return this.authnRequest; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index 7d109e327f3..0e869429a0c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -20,6 +20,9 @@ import java.nio.charset.StandardCharsets; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -43,6 +46,7 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.web.util.HtmlUtils; @@ -69,6 +73,9 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class); + private Saml2AuthenticationRequestResolver authenticationRequestResolver = mock( + Saml2AuthenticationRequestResolver.class); + private Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( Saml2AuthenticationRequestRepository.class); @@ -86,7 +93,12 @@ public void setup() { this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.request.setPathInfo("/saml2/authenticate/registration-id"); - this.filterChain = new MockFilterChain(); + this.filterChain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) { + ((HttpServletResponse) response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + }; this.rpBuilder = RelyingPartyRegistration.withRegistrationId("registration-id") .providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL)) .assertionConsumerServiceUrlTemplate("template") @@ -114,6 +126,12 @@ private static Saml2RedirectAuthenticationRequest.Builder redirectAuthentication .authenticationRequestUri(IDP_SSO_URL); } + private static Saml2RedirectAuthenticationRequest.Builder redirectAuthenticationRequest( + RelyingPartyRegistration registration) { + return Saml2RedirectAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest("request") + .authenticationRequestUri(IDP_SSO_URL); + } + private static Saml2PostAuthenticationRequest.Builder postAuthenticationRequest( Saml2AuthenticationRequestContext context) { return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context).samlRequest("request") @@ -287,4 +305,15 @@ public void doFilterWhenPathStartsWithRegistrationIdThenPosts() throws Exception verify(this.repository).findByRegistrationId("registration-id"); } + @Test + public void doFilterWhenCustomAuthenticationRequestResolverThenUses() throws Exception { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + Saml2RedirectAuthenticationRequest authenticationRequest = redirectAuthenticationRequest(registration).build(); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter( + this.authenticationRequestResolver); + given(this.authenticationRequestResolver.resolve(any())).willReturn(authenticationRequest); + filter.doFilterInternal(this.request, this.response, this.filterChain); + verify(this.authenticationRequestResolver).resolve(any()); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java new file mode 100644 index 00000000000..6902c234a76 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.opensaml.xmlsec.signature.support.SignatureConstants; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link OpenSamlAuthenticationRequestResolver} + */ +public class OpenSamlAuthenticationRequestResolverTests { + + private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; + + @Before + public void setUp() { + this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration(); + } + + @Test + public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.build(); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + assertThat(authnRequest.getAssertionConsumerServiceURL()) + .isEqualTo(registration.getAssertionConsumerServiceLocation()); + assertThat(authnRequest.getProtocolBinding()) + .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); + assertThat(authnRequest.getDestination()) + .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId()); + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNotNull(); + assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + assertThat(result.getSignature()).isNotEmpty(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + } + + @Test + public void resolveAuthenticationRequestWhenUnsignedRedirectThenRedirectsAndNoSignature() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(false)).build(); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + assertThat(authnRequest.getAssertionConsumerServiceURL()) + .isEqualTo(registration.getAssertionConsumerServiceLocation()); + assertThat(authnRequest.getProtocolBinding()) + .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); + assertThat(authnRequest.getDestination()) + .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId()); + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNotNull(); + assertThat(result.getSigAlg()).isNull(); + assertThat(result.getSignature()).isNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + } + + @Test + public void resolveAuthenticationRequestWhenSignedThenCredentialIsRequired() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + Saml2X509Credential credential = TestSaml2X509Credentials.relyingPartyVerifyingCredential(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build(); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> resolver.resolve(request, null)); + } + + @Test + public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails( + (party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false)) + .build(); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + assertThat(authnRequest.getAssertionConsumerServiceURL()) + .isEqualTo(registration.getAssertionConsumerServiceLocation()); + assertThat(authnRequest.getProtocolBinding()) + .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); + assertThat(authnRequest.getDestination()) + .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId()); + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).doesNotContain("Signature"); + } + + @Test + public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build(); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + assertThat(authnRequest.getAssertionConsumerServiceURL()) + .isEqualTo(registration.getAssertionConsumerServiceLocation()); + assertThat(authnRequest.getProtocolBinding()) + .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn()); + assertThat(authnRequest.getDestination()) + .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId()); + }); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).contains("Signature"); + } + + @Test + public void resolveAuthenticationRequestWhenSHA1SignRequestThenSigns() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/saml2/authenticate/registration-id"); + RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails( + (party) -> party.signingAlgorithms((algs) -> algs.add(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1))) + .build(); + OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, null); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isNotNull(); + assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1); + assertThat(result.getSignature()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + } + + private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) { + return new OpenSamlAuthenticationRequestResolver((request, id) -> registration); + } + +}