Skip to content

Commit 5061ae9

Browse files
committed
Add Saml2AuthenticationTokenConverter
Closes gh-8768
1 parent a10c2c6 commit 5061ae9

File tree

9 files changed

+386
-49
lines changed

9 files changed

+386
-49
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.security.config.annotation.web.configurers.AbstractAuthenticationFilterConfigurer;
2929
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
3030
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
31+
import org.springframework.security.core.Authentication;
3132
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
3233
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
3334
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
@@ -38,6 +39,8 @@
3839
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
3940
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
4041
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
42+
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
43+
import org.springframework.security.web.authentication.AuthenticationConverter;
4144
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
4245
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
4346
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -106,10 +109,25 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
106109

107110
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
108111

112+
private AuthenticationConverter authenticationConverter;
109113
private AuthenticationManager authenticationManager;
110114

111115
private Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter;
112116

117+
/**
118+
* Use this {@link AuthenticationConverter} when converting incoming requests to an {@link Authentication}.
119+
* By default the {@link Saml2AuthenticationTokenConverter} is used.
120+
*
121+
* @param authenticationConverter the {@link AuthenticationConverter} to use
122+
* @return the {@link Saml2LoginConfigurer} for further configuration
123+
* @since 5.4
124+
*/
125+
public Saml2LoginConfigurer<B> authenticationConverter(AuthenticationConverter authenticationConverter) {
126+
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
127+
this.authenticationConverter = authenticationConverter;
128+
return this;
129+
}
130+
113131
/**
114132
* Allows a configuration of a {@link AuthenticationManager} to be used during SAML 2 authentication.
115133
* If none is specified, the system will create one inject it into the {@link Saml2WebSsoAuthenticationFilter}
@@ -187,7 +205,7 @@ public void init(B http) throws Exception {
187205
}
188206

189207
saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(
190-
this.relyingPartyRegistrationRepository,
208+
getAuthenticationConverter(http),
191209
this.loginProcessingUrl
192210
);
193211
setAuthenticationFilter(saml2WebSsoAuthenticationFilter);
@@ -241,6 +259,14 @@ public void configure(B http) throws Exception {
241259
}
242260
}
243261

262+
private AuthenticationConverter getAuthenticationConverter(B http) {
263+
if (this.authenticationConverter == null) {
264+
return new Saml2AuthenticationTokenConverter(
265+
new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository));
266+
}
267+
return this.authenticationConverter;
268+
}
269+
244270
private void registerDefaultAuthenticationProvider(B http) {
245271
OpenSamlAuthenticationProvider provider = postProcess(new OpenSamlAuthenticationProvider());
246272
http.authenticationProvider(provider);

config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

+42
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@
6565
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
6666
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
6767
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
68+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
6869
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
6970
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
7071
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
7172
import org.springframework.security.web.FilterChainProxy;
73+
import org.springframework.security.web.authentication.AuthenticationConverter;
7274
import org.springframework.security.web.context.HttpRequestResponseHolder;
7375
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
7476
import org.springframework.security.web.context.SecurityContextRepository;
@@ -86,9 +88,13 @@
8688
import static org.mockito.Mockito.verify;
8789
import static org.mockito.Mockito.when;
8890
import static org.springframework.security.config.Customizer.withDefaults;
91+
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
8992
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
93+
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
9094
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
9195
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
96+
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
97+
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
9298
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
9399

94100
/**
@@ -101,6 +107,8 @@ public class Saml2LoginConfigurerTests {
101107
private static final GrantedAuthoritiesMapper AUTHORITIES_MAPPER =
102108
authorities -> Arrays.asList(new SimpleGrantedAuthority("TEST CONVERTED"));
103109
private static final Duration RESPONSE_TIME_VALIDATION_SKEW = Duration.ZERO;
110+
private static final String SIGNED_RESPONSE =
111+
"PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz48c2FtbDJwOlJlc3BvbnNlIHhtbG5zOnNhbWwycD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOnByb3RvY29sIiBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9ycC5leGFtcGxlLm9yZy9hY3MiIElEPSJfYzE3MzM2YTAtNTM1My00MTQ5LWI3MmMtMDNkOWY5YWYzMDdlIiBJc3N1ZUluc3RhbnQ9IjIwMjAtMDgtMDRUMjI6MDQ6NDUuMDE2WiIgVmVyc2lvbj0iMi4wIj48c2FtbDI6SXNzdWVyIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48ZHM6U2lnbmF0dXJlIHhtbG5zOmRzPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwLzA5L3htbGRzaWcjIj4KPGRzOlNpZ25lZEluZm8+CjxkczpDYW5vbmljYWxpemF0aW9uTWV0aG9kIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMS8xMC94bWwtZXhjLWMxNG4jIi8+CjxkczpTaWduYXR1cmVNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGRzaWctbW9yZSNyc2Etc2hhMjU2Ii8+CjxkczpSZWZlcmVuY2UgVVJJPSIjX2MxNzMzNmEwLTUzNTMtNDE0OS1iNzJjLTAzZDlmOWFmMzA3ZSI+CjxkczpUcmFuc2Zvcm1zPgo8ZHM6VHJhbnNmb3JtIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMC8wOS94bWxkc2lnI2VudmVsb3BlZC1zaWduYXR1cmUiLz4KPGRzOlRyYW5zZm9ybSBBbGdvcml0aG09Imh0dHA6Ly93d3cudzMub3JnLzIwMDEvMTAveG1sLWV4Yy1jMTRuIyIvPgo8L2RzOlRyYW5zZm9ybXM+CjxkczpEaWdlc3RNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGVuYyNzaGEyNTYiLz4KPGRzOkRpZ2VzdFZhbHVlPjYzTmlyenFzaDVVa0h1a3NuRWUrM0hWWU5aYWFsQW1OQXFMc1lGMlRuRDA9PC9kczpEaWdlc3RWYWx1ZT4KPC9kczpSZWZlcmVuY2U+CjwvZHM6U2lnbmVkSW5mbz4KPGRzOlNpZ25hdHVyZVZhbHVlPgpLMVlvWWJVUjBTclY4RTdVMkhxTTIvZUNTOTNoV25mOExnNnozeGZWMUlyalgzSXhWYkNvMVlYcnRBSGRwRVdvYTJKKzVOMmFNbFBHJiMxMzsKN2VpbDBZRC9xdUVRamRYbTNwQTBjZmEvY25pa2RuKzVhbnM0ZWQwanU1amo2dkpvZ2w2Smt4Q25LWUpwTU9HNzhtampmb0phengrWCYjMTM7CkM2NktQVStBYUdxeGVwUEQ1ZlhRdTFKSy9Jb3lBaitaa3k4Z2Jwc3VyZHFCSEJLRWxjdnVOWS92UGY0OGtBeFZBKzdtRGhNNUMvL1AmIzEzOwp0L084Y3NZYXB2UjZjdjZrdk45QXZ1N3FRdm9qVk1McHVxZWNJZDJwTUVYb0NSSnE2Nkd4MStNTUVPeHVpMWZZQlRoMEhhYjRmK3JyJiMxMzsKOEY2V1NFRC8xZllVeHliRkJqZ1Q4d2lEWHFBRU8wSVY4ZWRQeEE9PQo8L2RzOlNpZ25hdHVyZVZhbHVlPgo8L2RzOlNpZ25hdHVyZT48c2FtbDI6QXNzZXJ0aW9uIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIiBJRD0iQWUzZjQ5OGI4LTliMTctNDA3OC05ZDM1LTg2YTA4NDA4NDk5NSIgSXNzdWVJbnN0YW50PSIyMDIwLTA4LTA0VDIyOjA0OjQ1LjA3N1oiIFZlcnNpb249IjIuMCI+PHNhbWwyOklzc3Vlcj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48c2FtbDI6U3ViamVjdD48c2FtbDI6TmFtZUlEPnRlc3RAc2FtbC51c2VyPC9zYW1sMjpOYW1lSUQ+PHNhbWwyOlN1YmplY3RDb25maXJtYXRpb24gTWV0aG9kPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6Y206YmVhcmVyIj48c2FtbDI6U3ViamVjdENvbmZpcm1hdGlvbkRhdGEgTm90QmVmb3JlPSIyMDIwLTA4LTA0VDIxOjU5OjQ1LjA5MFoiIE5vdE9uT3JBZnRlcj0iMjA0MC0wNy0zMFQyMjowNTowNi4wODhaIiBSZWNpcGllbnQ9Imh0dHBzOi8vcnAuZXhhbXBsZS5vcmcvYWNzIi8+PC9zYW1sMjpTdWJqZWN0Q29uZmlybWF0aW9uPjwvc2FtbDI6U3ViamVjdD48c2FtbDI6Q29uZGl0aW9ucyBOb3RCZWZvcmU9IjIwMjAtMDgtMDRUMjE6NTk6NDUuMDgwWiIgTm90T25PckFmdGVyPSIyMDQwLTA3LTMwVDIyOjA1OjA2LjA4N1oiLz48L3NhbWwyOkFzc2VydGlvbj48L3NhbWwycDpSZXNwb25zZT4=";
104112

105113
@Autowired
106114
private ConfigurableApplicationContext context;
@@ -181,6 +189,23 @@ public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() thro
181189
assertThat(inflated).contains("ForceAuthn=\"true\"");
182190
}
183191

192+
@Test
193+
public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception {
194+
this.spring.register(CustomAuthenticationConverter.class).autowire();
195+
RelyingPartyRegistration relyingPartyRegistration = noCredentials()
196+
.assertingPartyDetails(party -> party
197+
.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential()))
198+
)
199+
.build();
200+
String response = new String(samlDecode(SIGNED_RESPONSE));
201+
when(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class)))
202+
.thenReturn(new Saml2AuthenticationToken(relyingPartyRegistration, response));
203+
this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId())
204+
.param("SAMLResponse", SIGNED_RESPONSE))
205+
.andExpect(redirectedUrl("/"));
206+
verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class));
207+
}
208+
184209
private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
185210
// get the OpenSamlAuthenticationProvider
186211
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -311,6 +336,23 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory() {
311336
}
312337
}
313338

339+
@EnableWebSecurity
340+
@Import(Saml2LoginConfigBeans.class)
341+
static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter {
342+
static final AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
343+
344+
@Override
345+
protected void configure(HttpSecurity http) throws Exception {
346+
http
347+
.authorizeRequests(authz -> authz
348+
.anyRequest().authenticated()
349+
)
350+
.saml2Login(saml2 -> saml2
351+
.authenticationConverter(authenticationConverter)
352+
);
353+
}
354+
}
355+
314356
private static AuthenticationManager getAuthenticationManagerMock(String role) {
315357
return new AuthenticationManager() {
316358

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java

+29-45
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,19 @@
1919
import javax.servlet.http.HttpServletRequest;
2020
import javax.servlet.http.HttpServletResponse;
2121

22-
import org.springframework.http.HttpMethod;
2322
import org.springframework.security.core.Authentication;
2423
import org.springframework.security.core.AuthenticationException;
25-
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
26-
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
2724
import org.springframework.security.saml2.core.Saml2Error;
28-
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
25+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
2926
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
27+
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
28+
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
3029
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
30+
import org.springframework.security.web.authentication.AuthenticationConverter;
3131
import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy;
32-
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
33-
import org.springframework.security.web.util.matcher.RequestMatcher;
3432
import org.springframework.util.Assert;
3533

36-
import static java.nio.charset.StandardCharsets.UTF_8;
3734
import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
38-
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
3935
import static org.springframework.util.StringUtils.hasText;
4036

4137
/**
@@ -44,8 +40,7 @@
4440
public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
4541

4642
public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}";
47-
private final RequestMatcher matcher;
48-
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
43+
private final AuthenticationConverter authenticationConverter;
4944

5045
/**
5146
* Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured
@@ -64,16 +59,30 @@ public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyin
6459
public Saml2WebSsoAuthenticationFilter(
6560
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
6661
String filterProcessesUrl) {
67-
super(filterProcessesUrl);
68-
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
69-
Assert.hasText(filterProcessesUrl, "filterProcessesUrl must contain a URL pattern");
62+
this(new Saml2AuthenticationTokenConverter
63+
(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
64+
filterProcessesUrl);
65+
}
66+
67+
/**
68+
* Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters
69+
*
70+
* @param authenticationConverter the strategy for converting an {@link HttpServletRequest}
71+
* into an {@link Authentication}
72+
* @param filterProcessingUrl the processing URL, must contain a {registrationId} variable
73+
* @since 5.4
74+
*/
75+
public Saml2WebSsoAuthenticationFilter(
76+
AuthenticationConverter authenticationConverter,
77+
String filterProcessingUrl) {
78+
super(filterProcessingUrl);
79+
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
80+
Assert.hasText(filterProcessingUrl, "filterProcessesUrl must contain a URL pattern");
7081
Assert.isTrue(
71-
filterProcessesUrl.contains("{registrationId}"),
82+
filterProcessingUrl.contains("{registrationId}"),
7283
"filterProcessesUrl must contain a {registrationId} match variable"
7384
);
74-
this.matcher = new AntPathRequestMatcher(filterProcessesUrl);
75-
setRequiresAuthenticationRequestMatcher(this.matcher);
76-
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
85+
this.authenticationConverter = authenticationConverter;
7786
setAllowSessionCreation(true);
7887
setSessionAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy());
7988
}
@@ -86,37 +95,12 @@ protected boolean requiresAuthentication(HttpServletRequest request, HttpServlet
8695
@Override
8796
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
8897
throws AuthenticationException {
89-
String saml2Response = request.getParameter("SAMLResponse");
90-
byte[] b = Saml2Utils.samlDecode(saml2Response);
91-
92-
String responseXml = inflateIfRequired(request, b);
93-
String registrationId = this.matcher.matcher(request).getVariables().get("registrationId");
94-
RelyingPartyRegistration rp =
95-
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
96-
if (rp == null) {
98+
Authentication authentication = this.authenticationConverter.convert(request);
99+
if (authentication == null) {
97100
Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND,
98-
"Relying Party Registration not found with ID: " + registrationId);
101+
"No relying party registration found");
99102
throw new Saml2AuthenticationException(saml2Error);
100103
}
101-
String applicationUri = Saml2ServletUtils.getApplicationUri(request);
102-
String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
103-
String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate(
104-
rp.getAssertionConsumerServiceLocation(), applicationUri, rp);
105-
RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp)
106-
.entityId(relyingPartyEntityId)
107-
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
108-
.build();
109-
Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
110-
relyingPartyRegistration, responseXml);
111104
return getAuthenticationManager().authenticate(authentication);
112105
}
113-
114-
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
115-
if (HttpMethod.GET.matches(request.getMethod())) {
116-
return Saml2Utils.samlInflate(b);
117-
}
118-
else {
119-
return new String(b, UTF_8);
120-
}
121-
}
122106
}

0 commit comments

Comments
 (0)