Skip to content

Commit 0df5844

Browse files
jzheauxAyush Kohli
authored and
Ayush Kohli
committed
Add RelyingPartyRegistrationResolver
Closes spring-projectsgh-9486
1 parent fe9eb5c commit 0df5844

File tree

7 files changed

+172
-33
lines changed

7 files changed

+172
-33
lines changed

docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -555,19 +555,24 @@ There are a number of reasons you may want to customize. Among them:
555555
* You may know that you will never be a multi-tenant application and so want to have a simpler URL scheme
556556
* You may identify tenants in a way other than by the URI path
557557

558-
To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `Converter<HttpServletRequest, RelyingPartyRegistration>`.
558+
To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `RelyingPartyRegistrationResolver`.
559559
The default looks up the registration id from the URI's last path element and looks it up in your `RelyingPartyRegistrationRepository`.
560560

561561
You can provide a simpler resolver that, for example, always returns the same relying party:
562562

563563
[source,java]
564564
----
565-
public class SingleRelyingPartyRegistrationResolver
566-
implements Converter<HttpServletRequest, RelyingPartyRegistration> {
565+
public class SingleRelyingPartyRegistrationResolver implements RelyingPartyRegistrationResolver {
566+
567+
private final RelyingPartyRegistrationResolver delegate;
568+
569+
public SingleRelyingPartyRegistrationResolver(RelyingPartyRegistrationRepository registrations) {
570+
this.delegate = new DefaultRelyingPartyRegistrationResolver(registrations);
571+
}
567572
568573
@Override
569-
public RelyingPartyRegistration convert(HttpServletRequest request) {
570-
return this.relyingParty;
574+
public RelyingPartyRegistration resolve(HttpServletRequest request, String registrationId) {
575+
return this.delegate.resolve(request, "single");
571576
}
572577
}
573578
----
@@ -1015,7 +1020,7 @@ You can publish a metadata endpoint by adding the `Saml2MetadataFilter` to the f
10151020

10161021
[source,java]
10171022
----
1018-
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver =
1023+
DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver =
10191024
new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository);
10201025
Saml2MetadataFilter filter = new Saml2MetadataFilter(
10211026
relyingPartyRegistrationResolver,
@@ -1035,11 +1040,9 @@ You can change this by calling the `setRequestMatcher` method on the filter:
10351040

10361041
[source,java]
10371042
----
1038-
filter.setRequestMatcher(new AntPathRequestMatcher("/saml2/metadata/{registrationId}", "GET"));
1043+
filter.setRequestMatcher(new AntPathRequestMatcher("/saml2/{registrationId}/metadata", "GET"));
10391044
----
10401045

1041-
ensuring that the `registrationId` hint is at the end of the path.
1042-
10431046
Or, if you have registered a custom relying party registration resolver in the constructor, then you can specify a path without a `registrationId` hint, like so:
10441047

10451048
[source,java]

saml2/saml2-service-provider/core/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,28 +42,42 @@
4242
* @since 5.4
4343
*/
4444
public final class DefaultRelyingPartyRegistrationResolver
45-
implements Converter<HttpServletRequest, RelyingPartyRegistration> {
45+
implements Converter<HttpServletRequest, RelyingPartyRegistration>, RelyingPartyRegistrationResolver {
4646

4747
private static final char PATH_DELIMITER = '/';
4848

4949
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
5050

51-
private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver();
51+
private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
5252

5353
public DefaultRelyingPartyRegistrationResolver(
5454
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
5555
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
5656
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
5757
}
5858

59+
/**
60+
* {@inheritDoc}
61+
*/
5962
@Override
6063
public RelyingPartyRegistration convert(HttpServletRequest request) {
61-
String registrationId = this.registrationIdResolver.convert(request);
62-
if (registrationId == null) {
64+
return resolve(request, null);
65+
}
66+
67+
/**
68+
* {@inheritDoc}
69+
*/
70+
@Override
71+
public RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId) {
72+
if (relyingPartyRegistrationId == null) {
73+
relyingPartyRegistrationId = this.registrationRequestMatcher.matcher(request).getVariables()
74+
.get("registrationId");
75+
}
76+
if (relyingPartyRegistrationId == null) {
6377
return null;
6478
}
6579
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
66-
.findByRegistrationId(registrationId);
80+
.findByRegistrationId(relyingPartyRegistrationId);
6781
if (relyingPartyRegistration == null) {
6882
return null;
6983
}
@@ -111,16 +125,4 @@ private static String getApplicationUri(HttpServletRequest request) {
111125
return uriComponents.toUriString();
112126
}
113127

114-
private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {
115-
116-
private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
117-
118-
@Override
119-
public String convert(HttpServletRequest request) {
120-
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
121-
return result.getVariables().get("registrationId");
122-
}
123-
124-
}
125-
126128
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2002-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import javax.servlet.http.HttpServletRequest;
20+
21+
import org.springframework.core.convert.converter.Converter;
22+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
23+
24+
/**
25+
* A contract for resolving a {@link RelyingPartyRegistration} from the HTTP request
26+
*
27+
* @author Josh Cummings
28+
* @since 5.5
29+
*/
30+
public interface RelyingPartyRegistrationResolver extends Converter<HttpServletRequest, RelyingPartyRegistration> {
31+
32+
@Override
33+
default RelyingPartyRegistration convert(HttpServletRequest request) {
34+
return resolve(request, null);
35+
}
36+
37+
/**
38+
* Resolve a {@link RelyingPartyRegistration} from the HTTP request, using the
39+
* {@code relyingPartyRegistrationId}, if it is provided
40+
* @param request the HTTP request
41+
* @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier
42+
* @return the resolved {@link RelyingPartyRegistration}
43+
*/
44+
RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId);
45+
46+
}

saml2/saml2-service-provider/core/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
4646

4747
public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml";
4848

49-
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter;
49+
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
5050

5151
private final Saml2MetadataResolver saml2MetadataResolver;
5252

@@ -55,11 +55,15 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
5555
private RequestMatcher requestMatcher = new AntPathRequestMatcher(
5656
"/saml2/service-provider-metadata/{registrationId}");
5757

58-
public Saml2MetadataFilter(
59-
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter,
58+
public Saml2MetadataFilter(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver,
6059
Saml2MetadataResolver saml2MetadataResolver) {
6160

62-
this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter;
61+
if (relyingPartyRegistrationResolver instanceof RelyingPartyRegistrationResolver) {
62+
this.relyingPartyRegistrationResolver = (RelyingPartyRegistrationResolver) relyingPartyRegistrationResolver;
63+
}
64+
else {
65+
this.relyingPartyRegistrationResolver = (request, id) -> relyingPartyRegistrationResolver.convert(request);
66+
}
6367
this.saml2MetadataResolver = saml2MetadataResolver;
6468
}
6569

@@ -71,14 +75,15 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
7175
chain.doFilter(request, response);
7276
return;
7377
}
74-
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request);
78+
String registrationId = matcher.getVariables().get("registrationId");
79+
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
80+
registrationId);
7581
if (relyingPartyRegistration == null) {
7682
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
7783
return;
7884
}
7985
String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
80-
String registrationId = relyingPartyRegistration.getRegistrationId();
81-
writeMetadataToResponse(response, registrationId, metadata);
86+
writeMetadataToResponse(response, relyingPartyRegistration.getRegistrationId(), metadata);
8287
}
8388

8489
private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)

saml2/saml2-service-provider/core/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,26 @@
2222
import org.junit.Before;
2323
import org.junit.Test;
2424

25+
import org.springframework.mock.web.MockFilterChain;
2526
import org.springframework.mock.web.MockHttpServletRequest;
2627
import org.springframework.mock.web.MockHttpServletResponse;
28+
import org.springframework.security.authentication.AuthenticationManager;
29+
import org.springframework.security.authentication.TestingAuthenticationToken;
30+
import org.springframework.security.core.Authentication;
2731
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
32+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
2833
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
34+
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
35+
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
36+
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
37+
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
38+
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
39+
import org.springframework.security.web.util.matcher.RequestMatcher;
2940

3041
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
3142
import static org.mockito.BDDMockito.given;
3243
import static org.mockito.Mockito.mock;
44+
import static org.mockito.Mockito.verify;
3345

3446
public class Saml2WebSsoAuthenticationFilterTests {
3547

@@ -41,6 +53,8 @@ public class Saml2WebSsoAuthenticationFilterTests {
4153

4254
private HttpServletResponse response = new MockHttpServletResponse();
4355

56+
private AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
57+
4458
@Before
4559
public void setup() {
4660
this.filter = new Saml2WebSsoAuthenticationFilter(this.repository);
@@ -84,4 +98,26 @@ public void attemptAuthenticationWhenRegistrationIdDoesNotExistThenThrowsExcepti
8498
.withMessage("No relying party registration found");
8599
}
86100

101+
@Test
102+
public void doFilterWhenPathStartsWithRegistrationIdThenAuthenticates() throws Exception {
103+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
104+
Authentication authentication = new TestingAuthenticationToken("user", "password");
105+
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
106+
given(this.authenticationManager.authenticate(authentication)).willReturn(authentication);
107+
String loginProcessingUrl = "/{registrationId}/login/saml2/sso";
108+
RequestMatcher matcher = new AntPathRequestMatcher(loginProcessingUrl);
109+
DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository);
110+
RelyingPartyRegistrationResolver resolver = (request, id) -> {
111+
String registrationId = matcher.matcher(request).getVariables().get("registrationId");
112+
return delegate.resolve(request, registrationId);
113+
};
114+
Saml2AuthenticationTokenConverter authenticationConverter = new Saml2AuthenticationTokenConverter(resolver);
115+
this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, loginProcessingUrl);
116+
this.filter.setAuthenticationManager(this.authenticationManager);
117+
this.request.setPathInfo("/registration-id/login/saml2/sso");
118+
this.request.setParameter("SAMLResponse", "response");
119+
this.filter.doFilter(this.request, this.response, new MockFilterChain());
120+
verify(this.repository).findByRegistrationId("registration-id");
121+
}
122+
87123
}

saml2/saml2-service-provider/core/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@
3636
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3737
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
3838
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
39+
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
40+
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
41+
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
42+
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
3943
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
44+
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
45+
import org.springframework.security.web.util.matcher.RequestMatcher;
4046
import org.springframework.web.util.HtmlUtils;
4147
import org.springframework.web.util.UriUtils;
4248

@@ -216,4 +222,29 @@ public void doFilterWhenRelyingPartyRegistrationNotFoundThenUnauthorized() throw
216222
assertThat(this.response.getStatus()).isEqualTo(401);
217223
}
218224

225+
@Test
226+
public void doFilterWhenPathStartsWithRegistrationIdThenPosts() throws Exception {
227+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
228+
.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build();
229+
RequestMatcher matcher = new AntPathRequestMatcher("/{registrationId}/saml2/authenticate");
230+
DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository);
231+
RelyingPartyRegistrationResolver resolver = (request, id) -> {
232+
String registrationId = matcher.matcher(request).getVariables().get("registrationId");
233+
return delegate.resolve(request, registrationId);
234+
};
235+
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(
236+
resolver);
237+
Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
238+
given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
239+
given(authenticationRequest.getRelayState()).willReturn("relay");
240+
given(authenticationRequest.getSamlRequest()).willReturn("saml");
241+
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
242+
given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest);
243+
this.filter = new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestContextResolver, this.factory);
244+
this.filter.setRedirectMatcher(matcher);
245+
this.request.setPathInfo("/registration-id/saml2/authenticate");
246+
this.filter.doFilter(this.request, this.response, new MockFilterChain());
247+
verify(this.repository).findByRegistrationId("registration-id");
248+
}
249+
219250
}

saml2/saml2-service-provider/core/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.junit.Test;
2626

2727
import org.springframework.http.HttpHeaders;
28+
import org.springframework.mock.web.MockFilterChain;
2829
import org.springframework.mock.web.MockHttpServletRequest;
2930
import org.springframework.mock.web.MockHttpServletResponse;
3031
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
@@ -37,6 +38,7 @@
3738
import static org.assertj.core.api.Assertions.assertThat;
3839
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
3940
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
41+
import static org.mockito.ArgumentMatchers.any;
4042
import static org.mockito.BDDMockito.given;
4143
import static org.mockito.Mockito.mock;
4244
import static org.mockito.Mockito.verify;
@@ -136,6 +138,20 @@ public void doFilterWhenSetMetadataFilenameThenUses() throws Exception {
136138
.isEqualTo("attachment; filename=\"%s\"; filename*=UTF-8''%s", fileName, encodedFileName);
137139
}
138140

141+
@Test
142+
public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception {
143+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
144+
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
145+
given(this.resolver.resolve(any())).willReturn("metadata");
146+
DefaultRelyingPartyRegistrationResolver resolver = new DefaultRelyingPartyRegistrationResolver(
147+
(id) -> this.repository.findByRegistrationId("registration-id"));
148+
this.filter = new Saml2MetadataFilter(resolver, this.resolver);
149+
this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
150+
this.request.setPathInfo("/metadata");
151+
this.filter.doFilter(this.request, this.response, new MockFilterChain());
152+
verify(this.repository).findByRegistrationId("registration-id");
153+
}
154+
139155
@Test
140156
public void setRequestMatcherWhenNullThenIllegalArgument() {
141157
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null));

0 commit comments

Comments
 (0)