providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
+ this.relyingPartyRegistrationRepository);
boolean singleProvider = providerUrlMap.size() == 1;
if (singleProvider) {
// Setup auto-redirect to provider login page
@@ -247,14 +262,16 @@ public void init(B http) throws Exception {
/**
* {@inheritDoc}
- *
+ *
* During the {@code configure} phase, a
* {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0
* AuthNRequest redirects
*/
@Override
public void configure(B http) throws Exception {
- http.addFilter(this.authenticationRequestEndpoint.build(http));
+ Saml2WebSsoAuthenticationRequestFilter filter = getAuthenticationRequestFilter(http);
+ filter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
+ http.addFilter(postProcess(filter));
super.configure(http);
if (this.authenticationManager == null) {
registerDefaultAuthenticationProvider(http);
@@ -264,6 +281,11 @@ public void configure(B http) throws Exception {
}
}
+ private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver(B http) {
+ RelyingPartyRegistrationRepository registrations = relyingPartyRegistrationRepository(http);
+ return new DefaultRelyingPartyRegistrationResolver(registrations);
+ }
+
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(B http) {
if (this.relyingPartyRegistrationRepository == null) {
this.relyingPartyRegistrationRepository = getSharedOrBean(http, RelyingPartyRegistrationRepository.class);
@@ -276,6 +298,46 @@ private void setAuthenticationRequestRepository(B http,
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
}
+ private Saml2WebSsoAuthenticationRequestFilter getAuthenticationRequestFilter(B http) {
+ Saml2AuthenticationRequestResolver authenticationRequestResolver = getAuthenticationRequestResolver(http);
+ if (authenticationRequestResolver != null) {
+ return new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestResolver);
+ }
+ return new Saml2WebSsoAuthenticationRequestFilter(getAuthenticationRequestContextResolver(http),
+ getAuthenticationRequestFactory(http));
+ }
+
+ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B http) {
+ if (this.authenticationRequestResolver != null) {
+ return this.authenticationRequestResolver;
+ }
+ return getBeanOrNull(http, Saml2AuthenticationRequestResolver.class);
+ }
+
+ private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory(B http) {
+ Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
+ if (resolver != null) {
+ return resolver;
+ }
+ if (version().startsWith("4")) {
+ return new OpenSaml4AuthenticationRequestFactory();
+ }
+ else {
+ return new OpenSamlAuthenticationRequestFactory();
+ }
+ }
+
+ private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) {
+ Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
+ Saml2AuthenticationRequestContextResolver.class);
+ if (resolver != null) {
+ return resolver;
+ }
+ RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
+ this.relyingPartyRegistrationRepository);
+ return new DefaultSaml2AuthenticationRequestContextResolver(registrationResolver);
+ }
+
private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter != null) {
return this.authenticationConverter;
@@ -325,8 +387,8 @@ private void initDefaultLoginFilter(B http) {
return;
}
loginPageGeneratingFilter.setSaml2LoginEnabled(true);
- loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(
- this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository));
+ loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
+ this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
}
@@ -380,46 +442,4 @@ private void setSharedObject(B http, Class clazz, C object) {
}
}
- private final class AuthenticationRequestEndpointConfig {
-
- private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
-
- private AuthenticationRequestEndpointConfig() {
- }
-
- private Filter build(B http) {
- Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
- Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
- Saml2AuthenticationRequestRepository repository = getAuthenticationRequestRepository(
- http);
- Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
- authenticationRequestResolver);
- filter.setAuthenticationRequestRepository(repository);
- return postProcess(filter);
- }
-
- private Saml2AuthenticationRequestFactory getResolver(B http) {
- Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
- if (resolver == null) {
- if (version().startsWith("4")) {
- return new OpenSaml4AuthenticationRequestFactory();
- }
- return new OpenSamlAuthenticationRequestFactory();
- }
- return resolver;
- }
-
- private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
- Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
- Saml2AuthenticationRequestContextResolver.class);
- if (resolver == null) {
- RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
- Saml2LoginConfigurer.this.relyingPartyRegistrationRepository);
- return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver);
- }
- return resolver;
- }
-
- }
-
}
diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java
index 83e555430df..73e8fb27dca 100644
--- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java
+++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java
@@ -80,9 +80,13 @@
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
+import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
+import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
+import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -104,6 +108,7 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -211,6 +216,41 @@ public void authenticationRequestWhenAuthnRequestContextConverterThenUses() thro
assertThat(inflated).contains("ForceAuthn=\"true\"");
}
+ @Test
+ public void authenticationRequestWhenAuthenticationRequestResolverBeanThenUses() throws Exception {
+ this.spring.register(CustomAuthenticationRequestResolverBean.class).autowire();
+ MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
+ UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
+ String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
+ String decoded = URLDecoder.decode(samlRequest, "UTF-8");
+ String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
+ assertThat(inflated).contains("ForceAuthn=\"true\"");
+ }
+
+ @Test
+ public void authenticationRequestWhenAuthenticationRequestResolverDslThenUses() throws Exception {
+ this.spring.register(CustomAuthenticationRequestResolverDsl.class).autowire();
+ MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
+ UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
+ String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
+ String decoded = URLDecoder.decode(samlRequest, "UTF-8");
+ String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
+ assertThat(inflated).contains("ForceAuthn=\"true\"");
+ }
+
+ @Test
+ public void authenticationRequestWhenAuthenticationRequestResolverAndFactoryThenResolverTakesPrecedence()
+ throws Exception {
+ this.spring.register(CustomAuthenticationRequestResolverPrecedence.class).autowire();
+ MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
+ UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
+ String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
+ String decoded = URLDecoder.decode(samlRequest, "UTF-8");
+ String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
+ assertThat(inflated).contains("ForceAuthn=\"true\"");
+ verifyNoInteractions(this.spring.getContext().getBean(Saml2AuthenticationRequestFactory.class));
+ }
+
@Test
public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception {
this.spring.register(CustomAuthenticationConverter.class).autowire();
@@ -506,6 +546,103 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory() {
}
+ @EnableWebSecurity
+ @Import(Saml2LoginConfigBeans.class)
+ static class CustomAuthenticationRequestResolverBean {
+
+ @Bean
+ SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
+ // @formatter:off
+ http
+ .authorizeRequests((authz) -> authz
+ .anyRequest().authenticated()
+ )
+ .saml2Login(Customizer.withDefaults());
+ // @formatter:on
+
+ return http.build();
+ }
+
+ @Bean
+ Saml2AuthenticationRequestResolver authenticationRequestResolver(
+ RelyingPartyRegistrationRepository registrations) {
+ RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
+ registrations);
+ OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
+ registrationResolver);
+ delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
+ return delegate;
+ }
+
+ }
+
+ @EnableWebSecurity
+ @Import(Saml2LoginConfigBeans.class)
+ static class CustomAuthenticationRequestResolverDsl {
+
+ @Bean
+ SecurityFilterChain filterChain(HttpSecurity http, RelyingPartyRegistrationRepository registrations)
+ throws Exception {
+ // @formatter:off
+ http
+ .authorizeRequests((authz) -> authz
+ .anyRequest().authenticated()
+ )
+ .saml2Login((saml2) -> saml2
+ .authenticationRequestResolver(authenticationRequestResolver(registrations))
+ );
+ // @formatter:on
+
+ return http.build();
+ }
+
+ Saml2AuthenticationRequestResolver authenticationRequestResolver(
+ RelyingPartyRegistrationRepository registrations) {
+ RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
+ registrations);
+ OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
+ registrationResolver);
+ delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
+ return delegate;
+ }
+
+ }
+
+ @EnableWebSecurity
+ @Import(Saml2LoginConfigBeans.class)
+ static class CustomAuthenticationRequestResolverPrecedence {
+
+ @Bean
+ SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
+ // @formatter:off
+ http
+ .authorizeRequests((authz) -> authz
+ .anyRequest().authenticated()
+ )
+ .saml2Login(Customizer.withDefaults());
+ // @formatter:on
+
+ return http.build();
+ }
+
+ @Bean
+ Saml2AuthenticationRequestFactory authenticationRequestFactory() {
+ return mock(Saml2AuthenticationRequestFactory.class);
+ }
+
+ @Bean
+ Saml2AuthenticationRequestResolver authenticationRequestResolver(
+ RelyingPartyRegistrationRepository registrations) {
+ RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
+ registrations);
+ OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
+ registrationResolver);
+ delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
+ return delegate;
+ }
+
+ }
+
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter {
diff --git a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc
index ba512b5a4a2..142186ac84c 100644
--- a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc
+++ b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc
@@ -176,92 +176,21 @@ var relyingPartyRegistration: RelyingPartyRegistration? =
There are a number of reasons that you may want to adjust an `AuthnRequest`.
For example, you may want `ForceAuthN` to be set to `true`, which Spring Security sets to `false` by default.
-If you don't need information from the `HttpServletRequest` to make your decision, then the easiest way is to xref:servlet/saml2/login/overview.adoc#servlet-saml2login-opensaml-customization[register a custom `AuthnRequestMarshaller` with OpenSAML].
-This will give you access to post-process the `AuthnRequest` instance before it's serialized.
-
-But, if you do need something from the request, then you can use create a custom `Saml2AuthenticationRequestContext` implementation and then a `Converter` to build an `AuthnRequest` yourself, like so:
-
-====
-.Java
-[source,java,role="primary"]
-----
-@Component
-public class AuthnRequestConverter implements
- Converter {
-
- private final AuthnRequestBuilder authnRequestBuilder;
- private final IssuerBuilder issuerBuilder;
-
- // ... constructor
-
- public AuthnRequest convert(Saml2AuthenticationRequestContext context) {
- MySaml2AuthenticationRequestContext myContext = (MySaml2AuthenticationRequestContext) context;
- Issuer issuer = issuerBuilder.buildObject();
- issuer.setValue(myContext.getIssuer());
-
- AuthnRequest authnRequest = authnRequestBuilder.buildObject();
- authnRequest.setIssuer(issuer);
- authnRequest.setDestination(myContext.getDestination());
- authnRequest.setAssertionConsumerServiceURL(myContext.getAssertionConsumerServiceUrl());
-
- // ... additional settings
-
- authRequest.setForceAuthn(myContext.getForceAuthn());
- return authnRequest;
- }
-}
-----
-
-.Kotlin
-[source,kotlin,role="secondary"]
-----
-@Component
-class AuthnRequestConverter : Converter {
- private val authnRequestBuilder: AuthnRequestBuilder? = null
- private val issuerBuilder: IssuerBuilder? = null
-
- // ... constructor
- override fun convert(context: Saml2AuthenticationRequestContext): AuthnRequest {
- val myContext: MySaml2AuthenticationRequestContext = context
- val issuer: Issuer = issuerBuilder.buildObject()
- issuer.value = myContext.getIssuer()
- val authnRequest: AuthnRequest = authnRequestBuilder.buildObject()
- authnRequest.issuer = issuer
- authnRequest.destination = myContext.getDestination()
- authnRequest.assertionConsumerServiceURL = myContext.getAssertionConsumerServiceUrl()
-
- // ... additional settings
- authRequest.setForceAuthn(myContext.getForceAuthn())
- return authnRequest
- }
-}
-----
-====
-
-Then, you can construct your own `Saml2AuthenticationRequestContextResolver` and `Saml2AuthenticationRequestFactory` and publish them as ``@Bean``s:
+You can customize elements of OpenSAML's `AuthnRequest` by publishing an `OpenSaml4AuthenticationRequestResolver` as a `@Bean`, like so:
====
.Java
[source,java,role="primary"]
----
@Bean
-Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver() {
- Saml2AuthenticationRequestContextResolver resolver =
- new DefaultSaml2AuthenticationRequestContextResolver();
- return request -> {
- Saml2AuthenticationRequestContext context = resolver.resolve(request);
- return new MySaml2AuthenticationRequestContext(context, request.getParameter("force") != null);
- };
-}
-
-@Bean
-Saml2AuthenticationRequestFactory authenticationRequestFactory(
- AuthnRequestConverter authnRequestConverter) {
-
- OpenSaml4AuthenticationRequestFactory authenticationRequestFactory =
- new OpenSaml4AuthenticationRequestFactory();
- authenticationRequestFactory.setAuthenticationRequestContextConverter(authnRequestConverter);
- return authenticationRequestFactory;
+Saml2AuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistrationRepository registrations) {
+ RelyingPartyRegistrationResolver registrationResolver =
+ new DefaultRelyingPartyRegistrationResolver(registrations);
+ OpenSaml4AuthenticationRequestResolver authenticationRequestResolver =
+ new OpenSaml4AuthenticationRequestResolver(registrationResolver);
+ authenticationRequestResolver.setAuthnRequestCustomizer((context) -> context
+ .getAuthnRequest().setForceAuthn(true));
+ return authenticationRequestResolver;
}
----
@@ -269,24 +198,14 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory(
[source,kotlin,role="secondary"]
----
@Bean
-open fun authenticationRequestContextResolver(): Saml2AuthenticationRequestContextResolver {
- val resolver: Saml2AuthenticationRequestContextResolver = DefaultSaml2AuthenticationRequestContextResolver()
- return Saml2AuthenticationRequestContextResolver { request: HttpServletRequest ->
- val context = resolver.resolve(request)
- MySaml2AuthenticationRequestContext(
- context,
- request.getParameter("force") != null
- )
- }
-}
-
-@Bean
-open fun authenticationRequestFactory(
- authnRequestConverter: AuthnRequestConverter?
-): Saml2AuthenticationRequestFactory? {
- val authenticationRequestFactory = OpenSaml4AuthenticationRequestFactory()
- authenticationRequestFactory.setAuthenticationRequestContextConverter(authnRequestConverter)
- return authenticationRequestFactory
+fun authenticationRequestResolver(registrations : RelyingPartyRegistrationRepository) : Saml2AuthenticationRequestResolver {
+ val registrationResolver : RelyingPartyRegistrationResolver =
+ new DefaultRelyingPartyRegistrationResolver(registrations)
+ val authenticationRequestResolver : OpenSaml4AuthenticationRequestResolver =
+ new OpenSaml4AuthenticationRequestResolver(registrationResolver)
+ authenticationRequestResolver.setAuthnRequestCustomizer((context) -> context
+ .getAuthnRequest().setForceAuthn(true))
+ return authenticationRequestResolver
}
----
====
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java
index 5fc84dd078a..d3e6dfd4c11 100644
--- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java
@@ -16,6 +16,7 @@
package org.springframework.security.saml2.provider.service.authentication;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
/**
@@ -54,6 +55,17 @@ public static Builder withAuthenticationRequestContext(Saml2AuthenticationReques
return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState());
}
+ /**
+ * Constructs a {@link Builder} from a {@link RelyingPartyRegistration} object.
+ * @param registration a relying party registration
+ * @return a modifiable builder object
+ * @since 5.7
+ */
+ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
+ String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
+ return new Builder().authenticationRequestUri(location);
+ }
+
/**
* Builder class for a {@link Saml2PostAuthenticationRequest} object.
*/
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java
index 80fec1d392b..eaafe98a177 100644
--- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java
@@ -16,6 +16,7 @@
package org.springframework.security.saml2.provider.service.authentication;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
/**
@@ -77,6 +78,18 @@ public static Builder withAuthenticationRequestContext(Saml2AuthenticationReques
return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState());
}
+ /**
+ * Constructs a {@link Saml2PostAuthenticationRequest.Builder} from a
+ * {@link RelyingPartyRegistration} object.
+ * @param registration a relying party registration
+ * @return a modifiable builder object
+ * @since 5.7
+ */
+ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
+ String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
+ return new Builder().authenticationRequestUri(location);
+ }
+
/**
* Builder class for a {@link Saml2RedirectAuthenticationRequest} object.
*/
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java
index 1d47d544e1e..4a5fea544b4 100644
--- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java
@@ -42,6 +42,7 @@
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher.MatchResult;
@@ -78,11 +79,7 @@
*/
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
- private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
-
- private Saml2AuthenticationRequestFactory authenticationRequestFactory;
-
- private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
+ private final Saml2AuthenticationRequestResolver authenticationRequestResolver;
private Saml2AuthenticationRequestRepository authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
@@ -129,11 +126,20 @@ private static Saml2AuthenticationRequestFactory requestFactory() {
public Saml2WebSsoAuthenticationRequestFilter(
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
+ this(new FactorySaml2AuthenticationRequestResolver(authenticationRequestContextResolver,
+ authenticationRequestFactory));
+ }
- Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
- Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
- this.authenticationRequestContextResolver = authenticationRequestContextResolver;
- this.authenticationRequestFactory = authenticationRequestFactory;
+ /**
+ * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the strategy for
+ * resolving the {@code AuthnRequest}
+ * @param authenticationRequestResolver the strategy for resolving the
+ * {@code AuthnRequest}
+ * @since 5.7
+ */
+ public Saml2WebSsoAuthenticationRequestFilter(Saml2AuthenticationRequestResolver authenticationRequestResolver) {
+ Assert.notNull(authenticationRequestResolver, "authenticationRequestResolver cannot be null");
+ this.authenticationRequestResolver = authenticationRequestResolver;
}
/**
@@ -146,16 +152,23 @@ public Saml2WebSsoAuthenticationRequestFilter(
@Deprecated
public void setAuthenticationRequestFactory(Saml2AuthenticationRequestFactory authenticationRequestFactory) {
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
- this.authenticationRequestFactory = authenticationRequestFactory;
+ Assert.isInstanceOf(FactorySaml2AuthenticationRequestResolver.class, this.authenticationRequestResolver,
+ "You cannot supply both a Saml2AuthenticationRequestResolver and a Saml2AuthenticationRequestFactory");
+ ((FactorySaml2AuthenticationRequestResolver) this.authenticationRequestResolver).authenticationRequestFactory = authenticationRequestFactory;
}
/**
* Use the given {@link RequestMatcher} that activates this filter for a given request
* @param redirectMatcher the {@link RequestMatcher} to use
+ * @deprecated Configure the request matcher in an implementation of
+ * {@link Saml2AuthenticationRequestResolver} instead
*/
+ @Deprecated
public void setRedirectMatcher(RequestMatcher redirectMatcher) {
Assert.notNull(redirectMatcher, "redirectMatcher cannot be null");
- this.redirectMatcher = redirectMatcher;
+ Assert.isInstanceOf(FactorySaml2AuthenticationRequestResolver.class, this.authenticationRequestResolver,
+ "You cannot supply a Saml2AuthenticationRequestResolver and a redirect matcher");
+ ((FactorySaml2AuthenticationRequestResolver) this.authenticationRequestResolver).redirectMatcher = redirectMatcher;
}
/**
@@ -174,30 +187,21 @@ public void setAuthenticationRequestRepository(
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
- MatchResult matcher = this.redirectMatcher.matcher(request);
- if (!matcher.isMatch()) {
+ AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestResolver.resolve(request);
+ if (authenticationRequest == null) {
filterChain.doFilter(request, response);
return;
}
-
- Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
- if (context == null) {
- response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
- return;
- }
- RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
- if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
- sendRedirect(request, response, context);
+ if (authenticationRequest instanceof Saml2RedirectAuthenticationRequest) {
+ sendRedirect(request, response, (Saml2RedirectAuthenticationRequest) authenticationRequest);
}
else {
- sendPost(request, response, context);
+ sendPost(request, response, (Saml2PostAuthenticationRequest) authenticationRequest);
}
}
private void sendRedirect(HttpServletRequest request, HttpServletResponse response,
- Saml2AuthenticationRequestContext context) throws IOException {
- Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
- .createRedirectAuthenticationRequest(context);
+ Saml2RedirectAuthenticationRequest authenticationRequest) throws IOException {
this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
UriComponentsBuilder uriBuilder = UriComponentsBuilder
.fromUriString(authenticationRequest.getAuthenticationRequestUri());
@@ -218,9 +222,7 @@ private void addParameter(String name, String value, UriComponentsBuilder builde
}
private void sendPost(HttpServletRequest request, HttpServletResponse response,
- Saml2AuthenticationRequestContext context) throws IOException {
- Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
- .createPostAuthenticationRequest(context);
+ Saml2PostAuthenticationRequest authenticationRequest) throws IOException {
this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
String html = createSamlPostRequestFormData(authenticationRequest);
response.setContentType(MediaType.TEXT_HTML_VALUE);
@@ -269,4 +271,41 @@ private String createSamlPostRequestFormData(Saml2PostAuthenticationRequest auth
return html.toString();
}
+ private static class FactorySaml2AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver {
+
+ private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
+
+ private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
+
+ private Saml2AuthenticationRequestFactory authenticationRequestFactory;
+
+ FactorySaml2AuthenticationRequestResolver(
+ Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
+ Saml2AuthenticationRequestFactory authenticationRequestFactory) {
+ Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
+ Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
+ this.authenticationRequestContextResolver = authenticationRequestContextResolver;
+ this.authenticationRequestFactory = authenticationRequestFactory;
+ }
+
+ @Override
+ public AbstractSaml2AuthenticationRequest resolve(HttpServletRequest request) {
+ MatchResult matcher = this.redirectMatcher.matcher(request);
+ if (!matcher.isMatch()) {
+ return null;
+ }
+ Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
+ if (context == null) {
+ return null;
+ }
+ Saml2MessageBinding binding = context.getRelyingPartyRegistration().getAssertingPartyDetails()
+ .getSingleSignOnServiceBinding();
+ if (binding == Saml2MessageBinding.REDIRECT) {
+ return this.authenticationRequestFactory.createRedirectAuthenticationRequest(context);
+ }
+ return this.authenticationRequestFactory.createPostAuthenticationRequest(context);
+ }
+
+ }
+
}
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java
new file mode 100644
index 00000000000..5f4b3333b22
--- /dev/null
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
+import org.opensaml.core.config.ConfigurationService;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.io.MarshallingException;
+import org.opensaml.saml.saml2.core.AuthnRequest;
+import org.opensaml.saml.saml2.core.Issuer;
+import org.opensaml.saml.saml2.core.NameID;
+import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
+import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
+import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
+import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
+import org.w3c.dom.Element;
+
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.OpenSamlInitializationService;
+import org.springframework.security.saml2.core.Saml2ParameterNames;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
+
+/**
+ * For internal use only. Intended for consolidating common behavior related to minting a
+ * SAML 2.0 Authn Request.
+ */
+class OpenSamlAuthenticationRequestResolver {
+
+ static {
+ OpenSamlInitializationService.initialize();
+ }
+
+ private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
+
+ private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
+
+ private final AuthnRequestBuilder authnRequestBuilder;
+
+ private final AuthnRequestMarshaller marshaller;
+
+ private final IssuerBuilder issuerBuilder;
+
+ private final NameIDBuilder nameIdBuilder;
+
+ /**
+ * Construct a {@link OpenSamlAuthenticationRequestResolver} using the provided
+ * parameters
+ * @param relyingPartyRegistrationResolver a strategy for resolving the
+ * {@link RelyingPartyRegistration} from the {@link HttpServletRequest}
+ */
+ OpenSamlAuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+ Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
+ this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+ XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
+ this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory()
+ .getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
+ Assert.notNull(this.marshaller, "logoutRequestMarshaller must be configured in OpenSAML");
+ this.authnRequestBuilder = (AuthnRequestBuilder) XMLObjectProviderRegistrySupport.getBuilderFactory()
+ .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
+ Assert.notNull(this.authnRequestBuilder, "authnRequestBuilder must be configured in OpenSAML");
+ this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
+ Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML");
+ this.nameIdBuilder = (NameIDBuilder) registry.getBuilderFactory().getBuilder(NameID.DEFAULT_ELEMENT_NAME);
+ Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
+ }
+
+ T resolve(HttpServletRequest request) {
+ return resolve(request, (registration, logoutRequest) -> {
+ });
+ }
+
+ T resolve(HttpServletRequest request,
+ BiConsumer authnRequestConsumer) {
+ RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
+ if (!result.isMatch()) {
+ return null;
+ }
+ String registrationId = result.getVariables().get("registrationId");
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId);
+ if (registration == null) {
+ return null;
+ }
+ AuthnRequest authnRequest = this.authnRequestBuilder.buildObject();
+ authnRequest.setForceAuthn(Boolean.FALSE);
+ authnRequest.setIsPassive(Boolean.FALSE);
+ authnRequest.setProtocolBinding(registration.getAssertionConsumerServiceBinding().getUrn());
+ Issuer iss = this.issuerBuilder.buildObject();
+ iss.setValue(registration.getEntityId());
+ authnRequest.setIssuer(iss);
+ authnRequest.setDestination(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+ authnRequest.setAssertionConsumerServiceURL(registration.getAssertionConsumerServiceLocation());
+ authnRequestConsumer.accept(registration, authnRequest);
+ if (authnRequest.getID() == null) {
+ authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
+ }
+ String relayState = UUID.randomUUID().toString();
+ Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding();
+ if (binding == Saml2MessageBinding.POST) {
+ if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
+ OpenSamlSigningUtils.sign(authnRequest, registration);
+ }
+ String xml = serialize(authnRequest);
+ String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8));
+ return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded)
+ .relayState(relayState).build();
+ }
+ else {
+ String xml = serialize(authnRequest);
+ String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
+ Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
+ .withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState);
+ if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
+ Map parameters = OpenSamlSigningUtils.sign(registration)
+ .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
+ .param(Saml2ParameterNames.RELAY_STATE, relayState).parameters();
+ builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG))
+ .signature(parameters.get(Saml2ParameterNames.SIGNATURE));
+ }
+ return (T) builder.build();
+ }
+ }
+
+ private String serialize(AuthnRequest authnRequest) {
+ try {
+ Element element = this.marshaller.marshall(authnRequest);
+ return SerializeSupport.nodeToString(element);
+ }
+ catch (MarshallingException ex) {
+ throw new Saml2Exception(ex);
+ }
+ }
+
+}
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java
new file mode 100644
index 00000000000..e6a8f94108b
--- /dev/null
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlSigningUtils.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import java.nio.charset.StandardCharsets;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
+import org.opensaml.core.xml.XMLObject;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.io.Marshaller;
+import org.opensaml.core.xml.io.MarshallingException;
+import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
+import org.opensaml.security.SecurityException;
+import org.opensaml.security.credential.BasicCredential;
+import org.opensaml.security.credential.Credential;
+import org.opensaml.security.credential.CredentialSupport;
+import org.opensaml.security.credential.UsageType;
+import org.opensaml.xmlsec.SignatureSigningParameters;
+import org.opensaml.xmlsec.SignatureSigningParametersResolver;
+import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
+import org.opensaml.xmlsec.crypto.XMLSigningUtil;
+import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
+import org.opensaml.xmlsec.signature.SignableXMLObject;
+import org.opensaml.xmlsec.signature.support.SignatureConstants;
+import org.opensaml.xmlsec.signature.support.SignatureSupport;
+import org.w3c.dom.Element;
+
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.util.Assert;
+import org.springframework.web.util.UriComponentsBuilder;
+import org.springframework.web.util.UriUtils;
+
+/**
+ * Utility methods for signing SAML components with OpenSAML
+ *
+ * For internal use only.
+ *
+ * @author Josh Cummings
+ */
+final class OpenSamlSigningUtils {
+
+ static String serialize(XMLObject object) {
+ try {
+ Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
+ Element element = marshaller.marshall(object);
+ return SerializeSupport.nodeToString(element);
+ }
+ catch (MarshallingException ex) {
+ throw new Saml2Exception(ex);
+ }
+ }
+
+ static O sign(O object, RelyingPartyRegistration relyingPartyRegistration) {
+ SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
+ try {
+ SignatureSupport.signObject(object, parameters);
+ return object;
+ }
+ catch (Exception ex) {
+ throw new Saml2Exception(ex);
+ }
+ }
+
+ static QueryParametersPartial sign(RelyingPartyRegistration registration) {
+ return new QueryParametersPartial(registration);
+ }
+
+ private static SignatureSigningParameters resolveSigningParameters(
+ RelyingPartyRegistration relyingPartyRegistration) {
+ List credentials = resolveSigningCredentials(relyingPartyRegistration);
+ List algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms();
+ List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
+ String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
+ SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
+ CriteriaSet criteria = new CriteriaSet();
+ BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
+ signingConfiguration.setSigningCredentials(credentials);
+ signingConfiguration.setSignatureAlgorithms(algorithms);
+ signingConfiguration.setSignatureReferenceDigestMethods(digests);
+ signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
+ criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration));
+ try {
+ SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
+ Assert.notNull(parameters, "Failed to resolve any signing credential");
+ return parameters;
+ }
+ catch (Exception ex) {
+ throw new Saml2Exception(ex);
+ }
+ }
+
+ private static List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) {
+ List credentials = new ArrayList<>();
+ for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) {
+ X509Certificate certificate = x509Credential.getCertificate();
+ PrivateKey privateKey = x509Credential.getPrivateKey();
+ BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
+ credential.setEntityId(relyingPartyRegistration.getEntityId());
+ credential.setUsageType(UsageType.SIGNING);
+ credentials.add(credential);
+ }
+ return credentials;
+ }
+
+ private OpenSamlSigningUtils() {
+
+ }
+
+ static class QueryParametersPartial {
+
+ final RelyingPartyRegistration registration;
+
+ final Map components = new LinkedHashMap<>();
+
+ QueryParametersPartial(RelyingPartyRegistration registration) {
+ this.registration = registration;
+ }
+
+ QueryParametersPartial param(String key, String value) {
+ this.components.put(key, value);
+ return this;
+ }
+
+ Map parameters() {
+ SignatureSigningParameters parameters = resolveSigningParameters(this.registration);
+ Credential credential = parameters.getSigningCredential();
+ String algorithmUri = parameters.getSignatureAlgorithm();
+ this.components.put("SigAlg", algorithmUri);
+ UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
+ for (Map.Entry component : this.components.entrySet()) {
+ builder.queryParam(component.getKey(),
+ UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
+ }
+ String queryString = builder.build(true).toString().substring(1);
+ try {
+ byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
+ queryString.getBytes(StandardCharsets.UTF_8));
+ String b64Signature = Saml2Utils.samlEncode(rawSignature);
+ this.components.put("Signature", b64Signature);
+ }
+ catch (SecurityException ex) {
+ throw new Saml2Exception(ex);
+ }
+ return this.components;
+ }
+
+ }
+
+}
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java
new file mode 100644
index 00000000000..f43d5ddecb1
--- /dev/null
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlVerificationUtils.java
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
+import org.opensaml.core.criterion.EntityIdCriterion;
+import org.opensaml.saml.common.xml.SAMLConstants;
+import org.opensaml.saml.criterion.ProtocolCriterion;
+import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion;
+import org.opensaml.saml.saml2.core.Issuer;
+import org.opensaml.saml.saml2.core.RequestAbstractType;
+import org.opensaml.saml.saml2.core.StatusResponseType;
+import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
+import org.opensaml.security.credential.Credential;
+import org.opensaml.security.credential.CredentialResolver;
+import org.opensaml.security.credential.UsageType;
+import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion;
+import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion;
+import org.opensaml.security.credential.impl.CollectionCredentialResolver;
+import org.opensaml.security.criteria.UsageCriterion;
+import org.opensaml.security.x509.BasicX509Credential;
+import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
+import org.opensaml.xmlsec.signature.Signature;
+import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
+import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
+
+import org.springframework.security.saml2.core.Saml2Error;
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.web.util.UriUtils;
+
+/**
+ * Utility methods for verifying SAML component signatures with OpenSAML
+ *
+ * For internal use only.
+ *
+ * @author Josh Cummings
+ */
+
+final class OpenSamlVerificationUtils {
+
+ static VerifierPartial verifySignature(StatusResponseType object, RelyingPartyRegistration registration) {
+ return new VerifierPartial(object, registration);
+ }
+
+ static VerifierPartial verifySignature(RequestAbstractType object, RelyingPartyRegistration registration) {
+ return new VerifierPartial(object, registration);
+ }
+
+ private OpenSamlVerificationUtils() {
+
+ }
+
+ static class VerifierPartial {
+
+ private final String id;
+
+ private final CriteriaSet criteria;
+
+ private final SignatureTrustEngine trustEngine;
+
+ VerifierPartial(StatusResponseType object, RelyingPartyRegistration registration) {
+ this.id = object.getID();
+ this.criteria = verificationCriteria(object.getIssuer());
+ this.trustEngine = trustEngine(registration);
+ }
+
+ VerifierPartial(RequestAbstractType object, RelyingPartyRegistration registration) {
+ this.id = object.getID();
+ this.criteria = verificationCriteria(object.getIssuer());
+ this.trustEngine = trustEngine(registration);
+ }
+
+ Saml2ResponseValidatorResult redirect(HttpServletRequest request, String objectParameterName) {
+ RedirectSignature signature = new RedirectSignature(request, objectParameterName);
+ if (signature.getAlgorithm() == null) {
+ return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Missing signature algorithm for object [" + this.id + "]"));
+ }
+ if (!signature.hasSignature()) {
+ return Saml2ResponseValidatorResult.failure(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Missing signature for object [" + this.id + "]"));
+ }
+ Collection errors = new ArrayList<>();
+ String algorithmUri = signature.getAlgorithm();
+ try {
+ if (!this.trustEngine.validate(signature.getSignature(), signature.getContent(), algorithmUri,
+ this.criteria, null)) {
+ errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Invalid signature for object [" + this.id + "]"));
+ }
+ }
+ catch (Exception ex) {
+ errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Invalid signature for object [" + this.id + "]: "));
+ }
+ return Saml2ResponseValidatorResult.failure(errors);
+ }
+
+ Saml2ResponseValidatorResult post(Signature signature) {
+ Collection errors = new ArrayList<>();
+ SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
+ try {
+ profileValidator.validate(signature);
+ }
+ catch (Exception ex) {
+ errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Invalid signature for object [" + this.id + "]: "));
+ }
+
+ try {
+ if (!this.trustEngine.validate(signature, this.criteria)) {
+ errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Invalid signature for object [" + this.id + "]"));
+ }
+ }
+ catch (Exception ex) {
+ errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+ "Invalid signature for object [" + this.id + "]: "));
+ }
+
+ return Saml2ResponseValidatorResult.failure(errors);
+ }
+
+ private CriteriaSet verificationCriteria(Issuer issuer) {
+ CriteriaSet criteria = new CriteriaSet();
+ criteria.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())));
+ criteria.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
+ criteria.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
+ return criteria;
+ }
+
+ private SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) {
+ Set credentials = new HashSet<>();
+ Collection keys = registration.getAssertingPartyDetails()
+ .getVerificationX509Credentials();
+ for (Saml2X509Credential key : keys) {
+ BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
+ cred.setUsageType(UsageType.SIGNING);
+ cred.setEntityId(registration.getAssertingPartyDetails().getEntityId());
+ credentials.add(cred);
+ }
+ CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
+ return new ExplicitKeySignatureTrustEngine(credentialsResolver,
+ DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver());
+ }
+
+ private static class RedirectSignature {
+
+ private final HttpServletRequest request;
+
+ private final String objectParameterName;
+
+ RedirectSignature(HttpServletRequest request, String objectParameterName) {
+ this.request = request;
+ this.objectParameterName = objectParameterName;
+ }
+
+ String getAlgorithm() {
+ return this.request.getParameter("SigAlg");
+ }
+
+ byte[] getContent() {
+ String query = String.format("%s=%s&SigAlg=%s", this.objectParameterName,
+ UriUtils.encode(this.request.getParameter(this.objectParameterName),
+ StandardCharsets.ISO_8859_1),
+ UriUtils.encode(getAlgorithm(), StandardCharsets.ISO_8859_1));
+ return query.getBytes(StandardCharsets.UTF_8);
+ }
+
+ byte[] getSignature() {
+ return Saml2Utils.samlDecode(this.request.getParameter("Signature"));
+ }
+
+ boolean hasSignature() {
+ return this.request.getParameter("Signature") != null;
+ }
+
+ }
+
+ }
+
+}
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2AuthenticationRequestResolver.java
new file mode 100644
index 00000000000..80bb0ecef6d
--- /dev/null
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2AuthenticationRequestResolver.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+/**
+ * A strategy for resolving a SAML 2.0 Authentication Request from the
+ * {@link HttpServletRequest}.
+ *
+ * @author Josh Cummings
+ * @since 5.7
+ */
+public interface Saml2AuthenticationRequestResolver {
+
+ T resolve(HttpServletRequest request);
+
+}
diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java
new file mode 100644
index 00000000000..daef78d49ae
--- /dev/null
+++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2Utils.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.zip.Deflater;
+import java.util.zip.DeflaterOutputStream;
+import java.util.zip.Inflater;
+import java.util.zip.InflaterOutputStream;
+
+import org.apache.commons.codec.binary.Base64;
+
+import org.springframework.security.saml2.Saml2Exception;
+
+/**
+ * Utility methods for working with serialized SAML messages.
+ *
+ * For internal use only.
+ *
+ * @author Josh Cummings
+ */
+final class Saml2Utils {
+
+ private static Base64 BASE64 = new Base64(0, new byte[] { '\n' });
+
+ private Saml2Utils() {
+ }
+
+ static String samlEncode(byte[] b) {
+ return BASE64.encodeAsString(b);
+ }
+
+ static byte[] samlDecode(String s) {
+ return BASE64.decode(s);
+ }
+
+ static byte[] samlDeflate(String s) {
+ try {
+ ByteArrayOutputStream b = new ByteArrayOutputStream();
+ DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true));
+ deflater.write(s.getBytes(StandardCharsets.UTF_8));
+ deflater.finish();
+ return b.toByteArray();
+ }
+ catch (IOException ex) {
+ throw new Saml2Exception("Unable to deflate string", ex);
+ }
+ }
+
+ static String samlInflate(byte[] b) {
+ try {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
+ iout.write(b);
+ iout.finish();
+ return new String(out.toByteArray(), StandardCharsets.UTF_8);
+ }
+ catch (IOException ex) {
+ throw new Saml2Exception("Unable to inflate string", ex);
+ }
+ }
+
+}
diff --git a/saml2/saml2-service-provider/src/opensaml3Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml3AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/opensaml3Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml3AuthenticationRequestResolver.java
new file mode 100644
index 00000000000..a6e462e1cae
--- /dev/null
+++ b/saml2/saml2-service-provider/src/opensaml3Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml3AuthenticationRequestResolver.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import java.time.Clock;
+import java.util.function.Consumer;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import org.joda.time.DateTime;
+import org.opensaml.saml.saml2.core.AuthnRequest;
+import org.opensaml.saml.saml2.core.LogoutRequest;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
+import org.springframework.util.Assert;
+
+/**
+ * A strategy for resolving a SAML 2.0 Authentication Request from the
+ * {@link HttpServletRequest} using OpenSAML.
+ *
+ * @author Josh Cummings
+ * @since 5.7
+ * @deprecated OpenSAML 3 has reached end-of-life so this version is no longer recommended
+ */
+@Deprecated
+public final class OpenSaml3AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver {
+
+ private final OpenSamlAuthenticationRequestResolver authnRequestResolver;
+
+ private Consumer contextConsumer = (parameters) -> {
+ };
+
+ private Clock clock = Clock.systemUTC();
+
+ /**
+ * Construct a {@link OpenSaml3AuthenticationRequestResolver}
+ */
+ public OpenSaml3AuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+ this.authnRequestResolver = new OpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver);
+ }
+
+ @Override
+ public T resolve(HttpServletRequest request) {
+ return this.authnRequestResolver.resolve(request, (registration, authnRequest) -> {
+ authnRequest.setIssueInstant(new DateTime(this.clock.millis()));
+ this.contextConsumer.accept(new AuthnRequestContext(request, registration, authnRequest));
+ });
+ }
+
+ /**
+ * Set a {@link Consumer} for modifying the OpenSAML {@link LogoutRequest}
+ * @param contextConsumer a consumer that accepts an {@link AuthnRequestContext}
+ */
+ public void setAuthnRequestCustomizer(Consumer contextConsumer) {
+ Assert.notNull(contextConsumer, "contextConsumer cannot be null");
+ this.contextConsumer = contextConsumer;
+ }
+
+ /**
+ * Use this {@link Clock} for generating the issued {@link DateTime}
+ * @param clock the {@link Clock} to use
+ */
+ public void setClock(Clock clock) {
+ Assert.notNull(clock, "clock must not be null");
+ this.clock = clock;
+ }
+
+ public static final class AuthnRequestContext {
+
+ private final HttpServletRequest request;
+
+ private final RelyingPartyRegistration registration;
+
+ private final AuthnRequest authnRequest;
+
+ public AuthnRequestContext(HttpServletRequest request, RelyingPartyRegistration registration,
+ AuthnRequest authnRequest) {
+ this.request = request;
+ this.registration = registration;
+ this.authnRequest = authnRequest;
+ }
+
+ public HttpServletRequest getRequest() {
+ return this.request;
+ }
+
+ public RelyingPartyRegistration getRelyingPartyRegistration() {
+ return this.registration;
+ }
+
+ public AuthnRequest getAuthnRequest() {
+ return this.authnRequest;
+ }
+
+ }
+
+}
diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java
new file mode 100644
index 00000000000..0bb7687458b
--- /dev/null
+++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import java.time.Clock;
+import java.time.Instant;
+import java.util.function.Consumer;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import org.opensaml.saml.saml2.core.AuthnRequest;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
+import org.springframework.util.Assert;
+
+/**
+ * A strategy for resolving a SAML 2.0 Authentication Request from the
+ * {@link HttpServletRequest} using OpenSAML.
+ *
+ * @author Josh Cummings
+ * @since 5.7
+ */
+public final class OpenSaml4AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver {
+
+ private final OpenSamlAuthenticationRequestResolver authnRequestResolver;
+
+ private Consumer contextConsumer = (parameters) -> {
+ };
+
+ private Clock clock = Clock.systemUTC();
+
+ /**
+ * Construct a {@link OpenSaml4AuthenticationRequestResolver}
+ */
+ public OpenSaml4AuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+ this.authnRequestResolver = new OpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver);
+ }
+
+ @Override
+ public T resolve(HttpServletRequest request) {
+ return this.authnRequestResolver.resolve(request, (registration, authnRequest) -> {
+ authnRequest.setIssueInstant(Instant.now(this.clock));
+ this.contextConsumer.accept(new AuthnRequestContext(request, registration, authnRequest));
+ });
+ }
+
+ /**
+ * Set a {@link Consumer} for modifying the OpenSAML {@link AuthnRequest}
+ * @param contextConsumer a consumer that accepts an {@link AuthnRequestContext}
+ */
+ public void setAuthnRequestCustomizer(Consumer contextConsumer) {
+ Assert.notNull(contextConsumer, "contextConsumer cannot be null");
+ this.contextConsumer = contextConsumer;
+ }
+
+ /**
+ * Use this {@link Clock} for generating the issued {@link Instant}
+ * @param clock the {@link Clock} to use
+ */
+ public void setClock(Clock clock) {
+ Assert.notNull(clock, "clock must not be null");
+ this.clock = clock;
+ }
+
+ public static final class AuthnRequestContext {
+
+ private final HttpServletRequest request;
+
+ private final RelyingPartyRegistration registration;
+
+ private final AuthnRequest authnRequest;
+
+ public AuthnRequestContext(HttpServletRequest request, RelyingPartyRegistration registration,
+ AuthnRequest authnRequest) {
+ this.request = request;
+ this.registration = registration;
+ this.authnRequest = authnRequest;
+ }
+
+ public HttpServletRequest getRequest() {
+ return this.request;
+ }
+
+ public RelyingPartyRegistration getRelyingPartyRegistration() {
+ return this.registration;
+ }
+
+ public AuthnRequest getAuthnRequest() {
+ return this.authnRequest;
+ }
+
+ }
+
+}
diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java
index 7d109e327f3..0e869429a0c 100644
--- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java
+++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java
@@ -20,6 +20,9 @@
import java.nio.charset.StandardCharsets;
import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletRequest;
+import jakarta.servlet.ServletResponse;
+import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -43,6 +46,7 @@
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.util.HtmlUtils;
@@ -69,6 +73,9 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class);
+ private Saml2AuthenticationRequestResolver authenticationRequestResolver = mock(
+ Saml2AuthenticationRequestResolver.class);
+
private Saml2AuthenticationRequestRepository authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
@@ -86,7 +93,12 @@ public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.request.setPathInfo("/saml2/authenticate/registration-id");
- this.filterChain = new MockFilterChain();
+ this.filterChain = new MockFilterChain() {
+ @Override
+ public void doFilter(ServletRequest request, ServletResponse response) {
+ ((HttpServletResponse) response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
+ }
+ };
this.rpBuilder = RelyingPartyRegistration.withRegistrationId("registration-id")
.providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL))
.assertionConsumerServiceUrlTemplate("template")
@@ -114,6 +126,12 @@ private static Saml2RedirectAuthenticationRequest.Builder redirectAuthentication
.authenticationRequestUri(IDP_SSO_URL);
}
+ private static Saml2RedirectAuthenticationRequest.Builder redirectAuthenticationRequest(
+ RelyingPartyRegistration registration) {
+ return Saml2RedirectAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest("request")
+ .authenticationRequestUri(IDP_SSO_URL);
+ }
+
private static Saml2PostAuthenticationRequest.Builder postAuthenticationRequest(
Saml2AuthenticationRequestContext context) {
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context).samlRequest("request")
@@ -287,4 +305,15 @@ public void doFilterWhenPathStartsWithRegistrationIdThenPosts() throws Exception
verify(this.repository).findByRegistrationId("registration-id");
}
+ @Test
+ public void doFilterWhenCustomAuthenticationRequestResolverThenUses() throws Exception {
+ RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
+ Saml2RedirectAuthenticationRequest authenticationRequest = redirectAuthenticationRequest(registration).build();
+ Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(
+ this.authenticationRequestResolver);
+ given(this.authenticationRequestResolver.resolve(any())).willReturn(authenticationRequest);
+ filter.doFilterInternal(this.request, this.response, this.filterChain);
+ verify(this.authenticationRequestResolver).resolve(any());
+ }
+
}
diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java
new file mode 100644
index 00000000000..6902c234a76
--- /dev/null
+++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java
@@ -0,0 +1,169 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.opensaml.xmlsec.signature.support.SignatureConstants;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.security.saml2.core.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+/**
+ * Tests for {@link OpenSamlAuthenticationRequestResolver}
+ */
+public class OpenSamlAuthenticationRequestResolverTests {
+
+ private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
+
+ @Before
+ public void setUp() {
+ this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration();
+ }
+
+ @Test
+ public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects() {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setPathInfo("/saml2/authenticate/registration-id");
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.build();
+ OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+ Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
+ assertThat(authnRequest.getAssertionConsumerServiceURL())
+ .isEqualTo(registration.getAssertionConsumerServiceLocation());
+ assertThat(authnRequest.getProtocolBinding())
+ .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn());
+ assertThat(authnRequest.getDestination())
+ .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+ assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId());
+ });
+ assertThat(result.getSamlRequest()).isNotEmpty();
+ assertThat(result.getRelayState()).isNotNull();
+ assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
+ assertThat(result.getSignature()).isNotEmpty();
+ assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
+ }
+
+ @Test
+ public void resolveAuthenticationRequestWhenUnsignedRedirectThenRedirectsAndNoSignature() {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setPathInfo("/saml2/authenticate/registration-id");
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
+ .assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(false)).build();
+ OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+ Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
+ assertThat(authnRequest.getAssertionConsumerServiceURL())
+ .isEqualTo(registration.getAssertionConsumerServiceLocation());
+ assertThat(authnRequest.getProtocolBinding())
+ .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn());
+ assertThat(authnRequest.getDestination())
+ .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+ assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId());
+ });
+ assertThat(result.getSamlRequest()).isNotEmpty();
+ assertThat(result.getRelayState()).isNotNull();
+ assertThat(result.getSigAlg()).isNull();
+ assertThat(result.getSignature()).isNull();
+ assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
+ }
+
+ @Test
+ public void resolveAuthenticationRequestWhenSignedThenCredentialIsRequired() {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setPathInfo("/saml2/authenticate/registration-id");
+ Saml2X509Credential credential = TestSaml2X509Credentials.relyingPartyVerifyingCredential();
+ RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
+ .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
+ OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+ assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> resolver.resolve(request, null));
+ }
+
+ @Test
+ public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setPathInfo("/saml2/authenticate/registration-id");
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails(
+ (party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false))
+ .build();
+ OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+ Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
+ assertThat(authnRequest.getAssertionConsumerServiceURL())
+ .isEqualTo(registration.getAssertionConsumerServiceLocation());
+ assertThat(authnRequest.getProtocolBinding())
+ .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn());
+ assertThat(authnRequest.getDestination())
+ .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+ assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId());
+ });
+ assertThat(result.getSamlRequest()).isNotEmpty();
+ assertThat(result.getRelayState()).isNotNull();
+ assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST);
+ assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).doesNotContain("Signature");
+ }
+
+ @Test
+ public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setPathInfo("/saml2/authenticate/registration-id");
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
+ .assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build();
+ OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+ Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
+ assertThat(authnRequest.getAssertionConsumerServiceURL())
+ .isEqualTo(registration.getAssertionConsumerServiceLocation());
+ assertThat(authnRequest.getProtocolBinding())
+ .isEqualTo(registration.getAssertionConsumerServiceBinding().getUrn());
+ assertThat(authnRequest.getDestination())
+ .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+ assertThat(authnRequest.getIssuer().getValue()).isEqualTo(registration.getEntityId());
+ });
+ assertThat(result.getSamlRequest()).isNotEmpty();
+ assertThat(result.getRelayState()).isNotNull();
+ assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST);
+ assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).contains("Signature");
+ }
+
+ @Test
+ public void resolveAuthenticationRequestWhenSHA1SignRequestThenSigns() {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setPathInfo("/saml2/authenticate/registration-id");
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails(
+ (party) -> party.signingAlgorithms((algs) -> algs.add(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1)))
+ .build();
+ OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+ Saml2RedirectAuthenticationRequest result = resolver.resolve(request, null);
+ assertThat(result.getSamlRequest()).isNotEmpty();
+ assertThat(result.getRelayState()).isNotNull();
+ assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1);
+ assertThat(result.getSignature()).isNotNull();
+ assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
+ }
+
+ private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) {
+ return new OpenSamlAuthenticationRequestResolver((request, id) -> registration);
+ }
+
+}