|
16 | 16 |
|
17 | 17 | package org.springframework.security.config.annotation.web.configurers.saml2;
|
18 | 18 |
|
| 19 | +import java.util.ArrayList; |
19 | 20 | import java.util.LinkedHashMap;
|
| 21 | +import java.util.List; |
20 | 22 | import java.util.Map;
|
21 | 23 |
|
| 24 | +import jakarta.servlet.http.HttpServletRequest; |
| 25 | + |
22 | 26 | import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
23 | 27 | import org.springframework.context.ApplicationContext;
|
24 | 28 | import org.springframework.security.authentication.AuthenticationManager;
|
|
33 | 37 | import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
|
34 | 38 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
35 | 39 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
| 40 | +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations; |
36 | 41 | import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
|
37 | 42 | import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter;
|
38 | 43 | import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
|
|
50 | 55 | import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
51 | 56 | import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
52 | 57 | import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
| 58 | +import org.springframework.security.web.util.matcher.ParameterRequestMatcher; |
53 | 59 | import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
|
54 | 60 | import org.springframework.security.web.util.matcher.RequestMatcher;
|
55 | 61 | import org.springframework.security.web.util.matcher.RequestMatchers;
|
@@ -111,7 +117,13 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
|
111 | 117 |
|
112 | 118 | private String loginPage;
|
113 | 119 |
|
114 |
| - private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI; |
| 120 | + private String authenticationRequestUri = "/saml2/authenticate"; |
| 121 | + |
| 122 | + private String[] authenticationRequestParams = { "registrationId={registrationId}" }; |
| 123 | + |
| 124 | + private RequestMatcher authenticationRequestMatcher = RequestMatchers.anyOf( |
| 125 | + new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI), |
| 126 | + new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams)); |
115 | 127 |
|
116 | 128 | private Saml2AuthenticationRequestResolver authenticationRequestResolver;
|
117 | 129 |
|
@@ -196,11 +208,31 @@ public Saml2LoginConfigurer<B> authenticationRequestResolver(
|
196 | 208 | * Request
|
197 | 209 | * @return the {@link Saml2LoginConfigurer} for further configuration
|
198 | 210 | * @since 6.0
|
| 211 | + * @deprecated Use {@link #authenticationRequestUriQuery} instead |
199 | 212 | */
|
200 | 213 | public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri) {
|
201 |
| - Assert.state(authenticationRequestUri.contains("{registrationId}"), |
202 |
| - "authenticationRequestUri must contain {registrationId} path variable"); |
203 |
| - this.authenticationRequestUri = authenticationRequestUri; |
| 214 | + return authenticationRequestUriQuery(authenticationRequestUri); |
| 215 | + } |
| 216 | + |
| 217 | + /** |
| 218 | + * Customize the URL that the SAML Authentication Request will be sent to. This method |
| 219 | + * also supports query parameters like so: <pre> |
| 220 | + * authenticationRequestUriQuery("/saml/authenticate?registrationId={registrationId}") |
| 221 | + * </pre> {@link RelyingPartyRegistrations} |
| 222 | + * @param authenticationRequestUriQuery the URI and query to use for the SAML 2.0 |
| 223 | + * Authentication Request |
| 224 | + * @return the {@link Saml2LoginConfigurer} for further configuration |
| 225 | + * @since 6.0 |
| 226 | + */ |
| 227 | + public Saml2LoginConfigurer<B> authenticationRequestUriQuery(String authenticationRequestUriQuery) { |
| 228 | + Assert.state(authenticationRequestUriQuery.contains("{registrationId}"), |
| 229 | + "authenticationRequestUri must contain {registrationId} path variable or query value"); |
| 230 | + String[] parts = authenticationRequestUriQuery.split("[?&]"); |
| 231 | + this.authenticationRequestUri = parts[0]; |
| 232 | + this.authenticationRequestParams = new String[parts.length - 1]; |
| 233 | + System.arraycopy(parts, 1, this.authenticationRequestParams, 0, parts.length - 1); |
| 234 | + this.authenticationRequestMatcher = new AntPathQueryRequestMatcher(this.authenticationRequestUri, |
| 235 | + this.authenticationRequestParams); |
204 | 236 | return this;
|
205 | 237 | }
|
206 | 238 |
|
@@ -255,7 +287,7 @@ public void init(B http) throws Exception {
|
255 | 287 | }
|
256 | 288 | else {
|
257 | 289 | Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
|
258 |
| - this.relyingPartyRegistrationRepository); |
| 290 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository); |
259 | 291 | boolean singleProvider = providerUrlMap.size() == 1;
|
260 | 292 | if (singleProvider) {
|
261 | 293 | // Setup auto-redirect to provider login page
|
@@ -336,8 +368,7 @@ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B ht
|
336 | 368 | }
|
337 | 369 | OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver(
|
338 | 370 | relyingPartyRegistrationRepository(http));
|
339 |
| - openSaml4AuthenticationRequestResolver |
340 |
| - .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); |
| 371 | + openSaml4AuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher); |
341 | 372 | return openSaml4AuthenticationRequestResolver;
|
342 | 373 | }
|
343 | 374 |
|
@@ -382,20 +413,28 @@ private void initDefaultLoginFilter(B http) {
|
382 | 413 | return;
|
383 | 414 | }
|
384 | 415 | loginPageGeneratingFilter.setSaml2LoginEnabled(true);
|
385 |
| - loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName( |
386 |
| - this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository)); |
| 416 | + loginPageGeneratingFilter |
| 417 | + .setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(this.authenticationRequestUri, |
| 418 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository)); |
387 | 419 | loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
|
388 | 420 | loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
|
389 | 421 | }
|
390 | 422 |
|
391 | 423 | @SuppressWarnings("unchecked")
|
392 |
| - private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, |
| 424 | + private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams, |
393 | 425 | RelyingPartyRegistrationRepository idpRepo) {
|
394 | 426 | Map<String, String> idps = new LinkedHashMap<>();
|
395 | 427 | if (idpRepo instanceof Iterable) {
|
396 | 428 | Iterable<RelyingPartyRegistration> repo = (Iterable<RelyingPartyRegistration>) idpRepo;
|
397 |
| - repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), |
398 |
| - p.getRegistrationId())); |
| 429 | + StringBuilder authRequestQuery = new StringBuilder("?"); |
| 430 | + for (String authRequestQueryParam : authRequestQueryParams) { |
| 431 | + authRequestQuery.append(authRequestQueryParam + "&"); |
| 432 | + } |
| 433 | + authRequestQuery.deleteCharAt(authRequestQuery.length() - 1); |
| 434 | + String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery; |
| 435 | + repo.forEach( |
| 436 | + (p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()), |
| 437 | + p.getRegistrationId())); |
399 | 438 | }
|
400 | 439 | return idps;
|
401 | 440 | }
|
@@ -437,4 +476,35 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
|
437 | 476 | }
|
438 | 477 | }
|
439 | 478 |
|
| 479 | + static class AntPathQueryRequestMatcher implements RequestMatcher { |
| 480 | + |
| 481 | + private final RequestMatcher matcher; |
| 482 | + |
| 483 | + AntPathQueryRequestMatcher(String path, String... params) { |
| 484 | + List<RequestMatcher> matchers = new ArrayList<>(); |
| 485 | + matchers.add(new AntPathRequestMatcher(path)); |
| 486 | + for (String param : params) { |
| 487 | + String[] parts = param.split("="); |
| 488 | + if (parts.length == 1) { |
| 489 | + matchers.add(new ParameterRequestMatcher(parts[0])); |
| 490 | + } |
| 491 | + else { |
| 492 | + matchers.add(new ParameterRequestMatcher(parts[0], parts[1])); |
| 493 | + } |
| 494 | + } |
| 495 | + this.matcher = new AndRequestMatcher(matchers); |
| 496 | + } |
| 497 | + |
| 498 | + @Override |
| 499 | + public boolean matches(HttpServletRequest request) { |
| 500 | + return matcher(request).isMatch(); |
| 501 | + } |
| 502 | + |
| 503 | + @Override |
| 504 | + public MatchResult matcher(HttpServletRequest request) { |
| 505 | + return this.matcher.matcher(request); |
| 506 | + } |
| 507 | + |
| 508 | + } |
| 509 | + |
440 | 510 | }
|
0 commit comments