Skip to content

Commit 8c0bdd5

Browse files
shazinjzheaux
authored andcommitted
Delegating Saml2AuthenticationRequestContext creation to Saml2AuthenticationRequestContextResolver
Saml2AuthenticationRequestContext creation logic is not extensible at the moment as it is provided inside of Saml2WebSsoAuthenticationRequestFilter. This change enables to custom logic to be used when creating Saml2AuthenticationRequestContext by taking the logic from the aforementioned filter to a seperate extensible API by the name Saml2AuthenticationRequestContextResolver. This provides following API contract and implementation: - Saml2AuthenticationRequestContextResolver - DefaultSaml2AuthenticationRequestContextResolver Fixes gh-8360
1 parent b9b8903 commit 8c0bdd5

File tree

5 files changed

+259
-35
lines changed

5 files changed

+259
-35
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,6 @@
1616

1717
package org.springframework.security.saml2.provider.service.servlet.filter;
1818

19-
import java.io.IOException;
20-
import java.util.function.Function;
21-
import javax.servlet.FilterChain;
22-
import javax.servlet.ServletException;
23-
import javax.servlet.http.HttpServletRequest;
24-
import javax.servlet.http.HttpServletResponse;
25-
2619
import org.springframework.http.MediaType;
2720
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
2821
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
@@ -31,6 +24,8 @@
3124
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3225
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
3326
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
27+
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
28+
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
3429
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
3530
import org.springframework.security.web.util.matcher.RequestMatcher;
3631
import org.springframework.security.web.util.matcher.RequestMatcher.MatchResult;
@@ -41,6 +36,12 @@
4136
import org.springframework.web.util.UriComponentsBuilder;
4237
import org.springframework.web.util.UriUtils;
4338

39+
import javax.servlet.FilterChain;
40+
import javax.servlet.ServletException;
41+
import javax.servlet.http.HttpServletRequest;
42+
import javax.servlet.http.HttpServletResponse;
43+
import java.io.IOException;
44+
4445
import static java.nio.charset.StandardCharsets.ISO_8859_1;
4546

4647
/**
@@ -70,6 +71,7 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
7071

7172
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
7273
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
74+
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
7375

7476
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
7577

@@ -121,6 +123,17 @@ public void setRedirectMatcher(RequestMatcher redirectMatcher) {
121123
this.redirectMatcher = redirectMatcher;
122124
}
123125

126+
/**
127+
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
128+
*
129+
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
130+
* @since 5.4
131+
*/
132+
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
133+
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
134+
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
135+
}
136+
124137
/**
125138
* {@inheritDoc}
126139
*/
@@ -141,38 +154,14 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
141154
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
142155
return;
143156
}
144-
if (this.logger.isDebugEnabled()) {
145-
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
146-
relyingParty.getRegistrationId() + "]");
147-
}
148-
Saml2AuthenticationRequestContext context = createRedirectAuthenticationRequestContext(request, relyingParty);
157+
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
149158
if (relyingParty.getProviderDetails().getBinding() == Saml2MessageBinding.REDIRECT) {
150159
sendRedirect(response, context);
151-
}
152-
else {
160+
} else {
153161
sendPost(response, context);
154162
}
155163
}
156164

157-
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
158-
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
159-
160-
String applicationUri = Saml2ServletUtils.getApplicationUri(request);
161-
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
162-
String localSpEntityId = resolver.apply(relyingParty.getLocalEntityIdTemplate());
163-
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceUrlTemplate());
164-
return Saml2AuthenticationRequestContext.builder()
165-
.issuer(localSpEntityId)
166-
.relyingPartyRegistration(relyingParty)
167-
.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
168-
.relayState(request.getParameter("RelayState"))
169-
.build();
170-
}
171-
172-
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
173-
return template -> Saml2ServletUtils.resolveUrlTemplate(template, applicationUri, relyingParty);
174-
}
175-
176165
private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context)
177166
throws IOException {
178167
Saml2RedirectAuthenticationRequest authenticationRequest =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import org.apache.commons.logging.Log;
20+
import org.apache.commons.logging.LogFactory;
21+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
22+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
23+
import org.springframework.util.Assert;
24+
import org.springframework.util.StringUtils;
25+
import org.springframework.web.util.UriComponents;
26+
import org.springframework.web.util.UriComponentsBuilder;
27+
28+
import javax.servlet.http.HttpServletRequest;
29+
import java.util.HashMap;
30+
import java.util.Map;
31+
import java.util.function.Function;
32+
33+
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
34+
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
35+
36+
/**
37+
* The default implementation for {@link Saml2AuthenticationRequestContextResolver}
38+
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
39+
*
40+
* @author Shazin Sadakath
41+
* @since 5.4
42+
*/
43+
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
44+
45+
private final Log logger = LogFactory.getLog(getClass());
46+
47+
private static final char PATH_DELIMITER = '/';
48+
49+
/**
50+
* {@inheritDoc}
51+
*/
52+
@Override
53+
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
54+
RelyingPartyRegistration relyingParty) {
55+
Assert.notNull(request, "request cannot be null");
56+
Assert.notNull(relyingParty, "relyingParty cannot be null");
57+
if (this.logger.isDebugEnabled()) {
58+
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
59+
relyingParty.getRegistrationId() + "]");
60+
}
61+
return createRedirectAuthenticationRequestContext(request, relyingParty);
62+
}
63+
64+
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
65+
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
66+
67+
String applicationUri = getApplicationUri(request);
68+
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
69+
String localSpEntityId = resolver.apply(relyingParty.getLocalEntityIdTemplate());
70+
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceUrlTemplate());
71+
return Saml2AuthenticationRequestContext.builder()
72+
.issuer(localSpEntityId)
73+
.relyingPartyRegistration(relyingParty)
74+
.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
75+
.relayState(request.getParameter("RelayState"))
76+
.build();
77+
}
78+
79+
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
80+
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
81+
}
82+
83+
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
84+
if (!StringUtils.hasText(template)) {
85+
return baseUrl;
86+
}
87+
88+
String entityId = relyingParty.getProviderDetails().getEntityId();
89+
String registrationId = relyingParty.getRegistrationId();
90+
Map<String, String> uriVariables = new HashMap<>();
91+
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
92+
.replaceQuery(null)
93+
.fragment(null)
94+
.build();
95+
String scheme = uriComponents.getScheme();
96+
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
97+
String host = uriComponents.getHost();
98+
uriVariables.put("baseHost", host == null ? "" : host);
99+
// following logic is based on HierarchicalUriComponents#toUriString()
100+
int port = uriComponents.getPort();
101+
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
102+
String path = uriComponents.getPath();
103+
if (StringUtils.hasLength(path)) {
104+
if (path.charAt(0) != PATH_DELIMITER) {
105+
path = PATH_DELIMITER + path;
106+
}
107+
}
108+
uriVariables.put("basePath", path == null ? "" : path);
109+
uriVariables.put("baseUrl", uriComponents.toUriString());
110+
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
111+
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
112+
113+
return UriComponentsBuilder.fromUriString(template)
114+
.buildAndExpand(uriVariables)
115+
.toUriString();
116+
}
117+
118+
private static String getApplicationUri(HttpServletRequest request) {
119+
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
120+
.replacePath(request.getContextPath())
121+
.replaceQuery(null)
122+
.fragment(null)
123+
.build();
124+
return uriComponents.toUriString();
125+
}
126+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
20+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
21+
22+
import javax.servlet.http.HttpServletRequest;
23+
24+
/**
25+
* This {@code Saml2AuthenticationRequestContextResolver} formulates a
26+
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
27+
*
28+
* @author Shazin Sadakath
29+
* @since 5.4
30+
*/
31+
public interface Saml2AuthenticationRequestContextResolver {
32+
33+
/**
34+
* This {@code resolve} method is defined to create a {@link Saml2AuthenticationRequestContext}
35+
*
36+
*
37+
* @param request the current request
38+
* @param relyingParty the relying party responsible for saml2 sso authentication
39+
* @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
40+
*/
41+
Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
42+
RelyingPartyRegistration relyingParty);
43+
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/TestSaml2SigningCredentials.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
3232
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING;
3333

34-
final class TestSaml2SigningCredentials {
34+
public final class TestSaml2SigningCredentials {
3535

36-
static Saml2X509Credential signingCredential() {
36+
public static Saml2X509Credential signingCredential() {
3737
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING, DECRYPTION);
3838
}
3939

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import org.junit.Before;
20+
import org.junit.Test;
21+
import org.springframework.mock.web.MockHttpServletRequest;
22+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
23+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
24+
25+
import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential;
26+
import static org.assertj.core.api.Assertions.*;
27+
28+
public class DefaultSaml2AuthenticationRequestContextResolverTests {
29+
30+
private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
31+
private static final String TEMPLATE = "template";
32+
private static final String REGISTRATION_ID = "registration-id";
33+
private static final String IDP_ENTITY_ID = "idp-entity-id";
34+
35+
private MockHttpServletRequest request;
36+
private RelyingPartyRegistration.Builder rpBuilder;
37+
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
38+
39+
@Before
40+
public void setup() {
41+
request = new MockHttpServletRequest();
42+
rpBuilder = RelyingPartyRegistration
43+
.withRegistrationId(REGISTRATION_ID)
44+
.providerDetails(c -> c.entityId(IDP_ENTITY_ID))
45+
.providerDetails(c -> c.webSsoUrl(IDP_SSO_URL))
46+
.assertionConsumerServiceUrlTemplate(TEMPLATE)
47+
.credentials(c -> c.add(signingCredential()));
48+
}
49+
50+
@Test
51+
public void resoleWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
52+
Saml2AuthenticationRequestContext authenticationRequestContext = authenticationRequestContextResolver.resolve(request, rpBuilder.build());
53+
54+
assertThat(authenticationRequestContext).isNotNull();
55+
assertThat(authenticationRequestContext.getAssertionConsumerServiceUrl()).isEqualTo(TEMPLATE);
56+
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(REGISTRATION_ID);
57+
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getEntityId()).isEqualTo(IDP_ENTITY_ID);
58+
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getWebSsoUrl()).isEqualTo(IDP_SSO_URL);
59+
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getCredentials()).isNotEmpty();
60+
}
61+
62+
@Test(expected = IllegalArgumentException.class)
63+
public void resolveWhenRequestAndRelyingPartyNullThenException() {
64+
authenticationRequestContextResolver.resolve(null, null);
65+
}
66+
}

0 commit comments

Comments
 (0)