|
16 | 16 |
|
17 | 17 | package org.springframework.security.config.annotation.web.configurers.saml2;
|
18 | 18 |
|
| 19 | +import java.util.ArrayList; |
19 | 20 | import java.util.LinkedHashMap;
|
| 21 | +import java.util.List; |
20 | 22 | import java.util.Map;
|
21 | 23 |
|
| 24 | +import jakarta.servlet.http.HttpServletRequest; |
| 25 | + |
22 | 26 | import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
23 | 27 | import org.springframework.context.ApplicationContext;
|
24 | 28 | import org.springframework.security.authentication.AuthenticationManager;
|
|
50 | 54 | import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
51 | 55 | import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
52 | 56 | import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
| 57 | +import org.springframework.security.web.util.matcher.ParameterRequestMatcher; |
53 | 58 | import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
|
54 | 59 | import org.springframework.security.web.util.matcher.RequestMatcher;
|
55 | 60 | import org.springframework.security.web.util.matcher.RequestMatchers;
|
@@ -113,6 +118,8 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
|
113 | 118 |
|
114 | 119 | private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI;
|
115 | 120 |
|
| 121 | + private String[] authenticationRequestParams = new String[0]; |
| 122 | + |
116 | 123 | private Saml2AuthenticationRequestResolver authenticationRequestResolver;
|
117 | 124 |
|
118 | 125 | private RequestMatcher loginProcessingUrl = RequestMatchers.anyOf(
|
@@ -198,12 +205,36 @@ public Saml2LoginConfigurer<B> authenticationRequestResolver(
|
198 | 205 | * @since 6.0
|
199 | 206 | */
|
200 | 207 | public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri) {
|
201 |
| - Assert.state(authenticationRequestUri.contains("{registrationId}"), |
202 |
| - "authenticationRequestUri must contain {registrationId} path variable"); |
| 208 | + return authenticationRequestUri(authenticationRequestUri, new String[0]); |
| 209 | + } |
| 210 | + |
| 211 | + /** |
| 212 | + * Customize the URL that the SAML Authentication Request will be sent to. |
| 213 | + * @param authenticationRequestUri the URI to use for the SAML 2.0 Authentication |
| 214 | + * Request |
| 215 | + * @param authenticationRequestParams any parameters to match on, of the form |
| 216 | + * {@code name=value} |
| 217 | + * @return the {@link Saml2LoginConfigurer} for further configuration |
| 218 | + * @since 6.0 |
| 219 | + */ |
| 220 | + public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri, |
| 221 | + String... authenticationRequestParams) { |
| 222 | + Assert.state(authenticationRequestUri.contains("{registrationId}") || anyContains(authenticationRequestParams), |
| 223 | + "authenticationRequestUri or an authenticationRequestParam must contain {registrationId} path variable"); |
203 | 224 | this.authenticationRequestUri = authenticationRequestUri;
|
| 225 | + this.authenticationRequestParams = authenticationRequestParams; |
204 | 226 | return this;
|
205 | 227 | }
|
206 | 228 |
|
| 229 | + private static boolean anyContains(String[] authenticationRequestParams) { |
| 230 | + for (String param : authenticationRequestParams) { |
| 231 | + if (param.contains("{registrationId}")) { |
| 232 | + return true; |
| 233 | + } |
| 234 | + } |
| 235 | + return false; |
| 236 | + } |
| 237 | + |
207 | 238 | /**
|
208 | 239 | * Specifies the URL to validate the credentials. If specified a custom URL, consider
|
209 | 240 | * specifying a custom {@link AuthenticationConverter} via
|
@@ -255,7 +286,7 @@ public void init(B http) throws Exception {
|
255 | 286 | }
|
256 | 287 | else {
|
257 | 288 | Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
|
258 |
| - this.relyingPartyRegistrationRepository); |
| 289 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository); |
259 | 290 | boolean singleProvider = providerUrlMap.size() == 1;
|
260 | 291 | if (singleProvider) {
|
261 | 292 | // Setup auto-redirect to provider login page
|
@@ -336,8 +367,14 @@ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B ht
|
336 | 367 | }
|
337 | 368 | OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver(
|
338 | 369 | relyingPartyRegistrationRepository(http));
|
339 |
| - openSaml4AuthenticationRequestResolver |
340 |
| - .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); |
| 370 | + if (this.authenticationRequestParams.length > 0) { |
| 371 | + openSaml4AuthenticationRequestResolver.setRequestMatcher( |
| 372 | + new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams)); |
| 373 | + } |
| 374 | + else { |
| 375 | + openSaml4AuthenticationRequestResolver |
| 376 | + .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); |
| 377 | + } |
341 | 378 | return openSaml4AuthenticationRequestResolver;
|
342 | 379 | }
|
343 | 380 |
|
@@ -383,18 +420,24 @@ private void initDefaultLoginFilter(B http) {
|
383 | 420 | }
|
384 | 421 | loginPageGeneratingFilter.setSaml2LoginEnabled(true);
|
385 | 422 | loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
|
386 |
| - this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository)); |
| 423 | + this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.authenticationRequestParams, this.relyingPartyRegistrationRepository)); |
387 | 424 | loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
|
388 | 425 | loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
|
389 | 426 | }
|
390 | 427 |
|
391 | 428 | @SuppressWarnings("unchecked")
|
392 |
| - private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, |
| 429 | + private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams, |
393 | 430 | RelyingPartyRegistrationRepository idpRepo) {
|
394 | 431 | Map<String, String> idps = new LinkedHashMap<>();
|
395 | 432 | if (idpRepo instanceof Iterable) {
|
396 | 433 | Iterable<RelyingPartyRegistration> repo = (Iterable<RelyingPartyRegistration>) idpRepo;
|
397 |
| - repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), |
| 434 | + StringBuilder authRequestQuery = new StringBuilder("?"); |
| 435 | + for (String authRequestQueryParam : authRequestQueryParams) { |
| 436 | + authRequestQuery.append(authRequestQueryParam + "&"); |
| 437 | + } |
| 438 | + authRequestQuery.deleteCharAt(authRequestQuery.length() - 1); |
| 439 | + String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery; |
| 440 | + repo.forEach((p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()), |
398 | 441 | p.getRegistrationId()));
|
399 | 442 | }
|
400 | 443 | return idps;
|
@@ -437,4 +480,35 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
|
437 | 480 | }
|
438 | 481 | }
|
439 | 482 |
|
| 483 | + static class AntPathQueryRequestMatcher implements RequestMatcher { |
| 484 | + |
| 485 | + private final RequestMatcher matcher; |
| 486 | + |
| 487 | + AntPathQueryRequestMatcher(String path, String... params) { |
| 488 | + List<RequestMatcher> matchers = new ArrayList<>(); |
| 489 | + matchers.add(new AntPathRequestMatcher(path)); |
| 490 | + for (String param : params) { |
| 491 | + String[] parts = param.split("="); |
| 492 | + if (parts.length == 1) { |
| 493 | + matchers.add(new ParameterRequestMatcher(parts[0])); |
| 494 | + } |
| 495 | + else { |
| 496 | + matchers.add(new ParameterRequestMatcher(parts[0], parts[1])); |
| 497 | + } |
| 498 | + } |
| 499 | + this.matcher = new AndRequestMatcher(matchers); |
| 500 | + } |
| 501 | + |
| 502 | + @Override |
| 503 | + public boolean matches(HttpServletRequest request) { |
| 504 | + return matcher(request).isMatch(); |
| 505 | + } |
| 506 | + |
| 507 | + @Override |
| 508 | + public MatchResult matcher(HttpServletRequest request) { |
| 509 | + return this.matcher.matcher(request); |
| 510 | + } |
| 511 | + |
| 512 | + } |
| 513 | + |
440 | 514 | }
|
0 commit comments