Skip to content

Commit a10c2c6

Browse files
committed
Polish DefaultSaml2AuthenticationRequestContextResolver
Issue gh-8360 Issue gh-8887
1 parent 015281f commit a10c2c6

File tree

7 files changed

+75
-134
lines changed

7 files changed

+75
-134
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

+17-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -35,6 +35,9 @@
3535
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
3636
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
3737
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
38+
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
39+
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
40+
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
3841
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
3942
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
4043
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -317,15 +320,16 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
317320

318321
private final class AuthenticationRequestEndpointConfig {
319322
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
323+
320324
private AuthenticationRequestEndpointConfig() {
321325
}
322326

323327
private Filter build(B http) {
324328
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
329+
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
325330

326331
return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
327-
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
328-
authenticationRequestResolver));
332+
contextResolver, authenticationRequestResolver));
329333
}
330334

331335
private Saml2AuthenticationRequestFactory getResolver(B http) {
@@ -335,6 +339,16 @@ private Saml2AuthenticationRequestFactory getResolver(B http) {
335339
}
336340
return resolver;
337341
}
342+
343+
private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
344+
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
345+
if (resolver == null) {
346+
return new DefaultSaml2AuthenticationRequestContextResolver(
347+
new DefaultRelyingPartyRegistrationResolver(
348+
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
349+
}
350+
return resolver;
351+
}
338352
}
339353

340354
}

config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

+4-16
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,8 @@
6565
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
6666
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
6767
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
68-
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
6968
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
7069
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
71-
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
7270
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
7371
import org.springframework.security.web.FilterChainProxy;
7472
import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -87,6 +85,7 @@
8785
import static org.mockito.Mockito.mock;
8886
import static org.mockito.Mockito.verify;
8987
import static org.mockito.Mockito.when;
88+
import static org.springframework.security.config.Customizer.withDefaults;
9089
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
9190
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
9291
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -161,11 +160,11 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t
161160
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
162161
Saml2AuthenticationRequestContextResolver resolver =
163162
CustomAuthenticationRequestContextResolver.resolver;
164-
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
163+
when(resolver.resolve(any(HttpServletRequest.class)))
165164
.thenReturn(context);
166165
this.mvc.perform(get("/saml2/authenticate/registration-id"))
167166
.andExpect(status().isFound());
168-
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
167+
verify(resolver).resolve(any(HttpServletRequest.class));
169168
}
170169

171170
@Test
@@ -276,22 +275,11 @@ static class CustomAuthenticationRequestContextResolver extends WebSecurityConfi
276275

277276
@Override
278277
protected void configure(HttpSecurity http) throws Exception {
279-
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
280-
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
281-
@Override
282-
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
283-
filter.setAuthenticationRequestContextResolver(resolver);
284-
return filter;
285-
}
286-
};
287-
288278
http
289279
.authorizeRequests(authz -> authz
290280
.anyRequest().authenticated()
291281
)
292-
.saml2Login(saml2 -> saml2
293-
.addObjectPostProcessor(processor)
294-
);
282+
.saml2Login(withDefaults());
295283
}
296284

297285
@Bean

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

+13-23
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3131
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
3232
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
33+
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
3334
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
3435
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
3536
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -69,9 +70,8 @@
6970
*/
7071
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
7172

72-
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
73+
private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
7374
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
74-
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
7575

7676
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
7777

@@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
8383
*/
8484
@Deprecated
8585
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
86-
this(relyingPartyRegistrationRepository,
86+
this(new DefaultSaml2AuthenticationRequestContextResolver(
87+
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
8788
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
8889
}
8990

9091
/**
9192
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
9293
*
93-
* @param relyingPartyRegistrationRepository a repository for relying party configurations
94+
* @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
9495
* @since 5.4
9596
*/
96-
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
97+
public Saml2WebSsoAuthenticationRequestFilter(
98+
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
9799
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
98-
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
100+
101+
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
99102
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
100-
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
103+
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
101104
this.authenticationRequestFactory = authenticationRequestFactory;
102105
}
103106

@@ -123,17 +126,6 @@ public void setRedirectMatcher(RequestMatcher redirectMatcher) {
123126
this.redirectMatcher = redirectMatcher;
124127
}
125128

126-
/**
127-
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
128-
*
129-
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
130-
* @since 5.4
131-
*/
132-
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
133-
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
134-
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
135-
}
136-
137129
/**
138130
* {@inheritDoc}
139131
*/
@@ -147,14 +139,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
147139
return;
148140
}
149141

150-
String registrationId = matcher.getVariables().get("registrationId");
151-
RelyingPartyRegistration relyingParty =
152-
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
153-
if (relyingParty == null) {
142+
Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
143+
if (context == null) {
154144
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
155145
return;
156146
}
157-
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
147+
RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
158148
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
159149
sendRedirect(response, context);
160150
} else {

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java

+15-63
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,45 @@
1616

1717
package org.springframework.security.saml2.provider.service.web;
1818

19-
import java.util.HashMap;
20-
import java.util.Map;
21-
import java.util.function.Function;
2219
import javax.servlet.http.HttpServletRequest;
2320

2421
import org.apache.commons.logging.Log;
2522
import org.apache.commons.logging.LogFactory;
2623

24+
import org.springframework.core.convert.converter.Converter;
2725
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
2826
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
2927
import org.springframework.util.Assert;
30-
import org.springframework.util.StringUtils;
31-
import org.springframework.web.util.UriComponents;
32-
import org.springframework.web.util.UriComponentsBuilder;
33-
34-
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
35-
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
3628

3729
/**
3830
* The default implementation for {@link Saml2AuthenticationRequestContextResolver}
3931
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
4032
*
4133
* @author Shazin Sadakath
34+
* @author Josh Cummings
4235
* @since 5.4
4336
*/
4437
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
4538

4639
private final Log logger = LogFactory.getLog(getClass());
4740

48-
private static final char PATH_DELIMITER = '/';
41+
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
42+
43+
public DefaultSaml2AuthenticationRequestContextResolver
44+
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
45+
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
46+
}
4947

5048
/**
5149
* {@inheritDoc}
5250
*/
5351
@Override
54-
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
55-
RelyingPartyRegistration relyingParty) {
52+
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
5653
Assert.notNull(request, "request cannot be null");
57-
Assert.notNull(relyingParty, "relyingParty cannot be null");
54+
RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request);
55+
if (relyingParty == null) {
56+
return null;
57+
}
5858
if (this.logger.isDebugEnabled()) {
5959
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
6060
relyingParty.getRegistrationId() + "]");
@@ -65,59 +65,11 @@ public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
6565
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
6666
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
6767

68-
String applicationUri = getApplicationUri(request);
69-
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
70-
String localSpEntityId = resolver.apply(relyingParty.getEntityId());
71-
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation());
7268
return Saml2AuthenticationRequestContext.builder()
73-
.issuer(localSpEntityId)
69+
.issuer(relyingParty.getEntityId())
7470
.relyingPartyRegistration(relyingParty)
75-
.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
71+
.assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation())
7672
.relayState(request.getParameter("RelayState"))
7773
.build();
7874
}
79-
80-
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
81-
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
82-
}
83-
84-
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
85-
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
86-
String registrationId = relyingParty.getRegistrationId();
87-
Map<String, String> uriVariables = new HashMap<>();
88-
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
89-
.replaceQuery(null)
90-
.fragment(null)
91-
.build();
92-
String scheme = uriComponents.getScheme();
93-
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
94-
String host = uriComponents.getHost();
95-
uriVariables.put("baseHost", host == null ? "" : host);
96-
// following logic is based on HierarchicalUriComponents#toUriString()
97-
int port = uriComponents.getPort();
98-
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
99-
String path = uriComponents.getPath();
100-
if (StringUtils.hasLength(path)) {
101-
if (path.charAt(0) != PATH_DELIMITER) {
102-
path = PATH_DELIMITER + path;
103-
}
104-
}
105-
uriVariables.put("basePath", path == null ? "" : path);
106-
uriVariables.put("baseUrl", uriComponents.toUriString());
107-
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
108-
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
109-
110-
return UriComponentsBuilder.fromUriString(template)
111-
.buildAndExpand(uriVariables)
112-
.toUriString();
113-
}
114-
115-
private static String getApplicationUri(HttpServletRequest request) {
116-
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
117-
.replacePath(request.getContextPath())
118-
.replaceQuery(null)
119-
.fragment(null)
120-
.build();
121-
return uriComponents.toUriString();
122-
}
12375
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java

+5-7
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@
1616

1717
package org.springframework.security.saml2.provider.service.web;
1818

19-
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
20-
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
21-
2219
import javax.servlet.http.HttpServletRequest;
2320

21+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
22+
2423
/**
2524
* This {@code Saml2AuthenticationRequestContextResolver} formulates a
2625
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
2726
*
2827
* @author Shazin Sadakath
28+
* @author Josh Cummings
2929
* @since 5.4
3030
*/
3131
public interface Saml2AuthenticationRequestContextResolver {
@@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver {
3535
*
3636
*
3737
* @param request the current request
38-
* @param relyingParty the relying party responsible for saml2 sso authentication
39-
* @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
38+
* @return the created {@link Saml2AuthenticationRequestContext} for the request
4039
*/
41-
Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
42-
RelyingPartyRegistration relyingParty);
40+
Saml2AuthenticationRequestContext resolve(HttpServletRequest request);
4341
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
3131
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3232
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
33+
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
3334
import org.springframework.web.util.HtmlUtils;
3435
import org.springframework.web.util.UriUtils;
3536

@@ -41,6 +42,7 @@
4142
import static org.mockito.Mockito.verifyNoInteractions;
4243
import static org.mockito.Mockito.when;
4344
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
45+
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
4446
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
4547

4648
public class Saml2WebSsoAuthenticationRequestFilterTests {
@@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
4951
private Saml2WebSsoAuthenticationRequestFilter filter;
5052
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
5153
private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
54+
private Saml2AuthenticationRequestContextResolver resolver =
55+
mock(Saml2AuthenticationRequestContextResolver.class);
5256
private MockHttpServletRequest request;
5357
private MockHttpServletResponse response;
5458
private MockFilterChain filterChain;
@@ -188,12 +192,14 @@ public void doFilterWhenCustomAuthenticationRequestFactoryThenUses() throws Exce
188192
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
189193
when(authenticationRequest.getRelayState()).thenReturn("relay");
190194
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
191-
when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
195+
when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext()
196+
.relyingPartyRegistration(relyingParty)
197+
.build());
192198
when(this.factory.createPostAuthenticationRequest(any()))
193199
.thenReturn(authenticationRequest);
194200

195201
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
196-
(this.repository, this.factory);
202+
(this.resolver, this.factory);
197203
filter.doFilterInternal(this.request, this.response, this.filterChain);
198204
assertThat(this.response.getContentAsString())
199205
.contains("<form action=\"uri\" method=\"post\">")

0 commit comments

Comments
 (0)