Skip to content

Support multiple SingleLogoutService bindings. #11287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -147,7 +147,7 @@ public Saml2LogoutConfigurer(ApplicationContext context) {
* <p>
* The Relying Party triggers logout by POSTing to the endpoint. The Asserting Party
* triggers logout based on what is specified by
* {@link RelyingPartyRegistration#getSingleLogoutServiceBinding()}.
* {@link RelyingPartyRegistration#getSingleLogoutServiceBindings()}.
* @param logoutUrl the URL that will invoke logout
* @return the {@link LogoutConfigurer} for further customizations
* @see LogoutConfigurer#logoutUrl(String)
Expand Down Expand Up @@ -343,7 +343,7 @@ public final class LogoutRequestConfigurer {
*
* <p>
* The Asserting Party should use whatever HTTP method specified in
* {@link RelyingPartyRegistration#getSingleLogoutServiceBinding()}.
* {@link RelyingPartyRegistration#getSingleLogoutServiceBindings()}.
* @param logoutUrl the URL that will receive the SAML 2.0 Logout Request
* @return the {@link LogoutRequestConfigurer} for further customizations
* @see Saml2LogoutConfigurer#logoutUrl(String)
Expand Down Expand Up @@ -425,7 +425,7 @@ public final class LogoutResponseConfigurer {
*
* <p>
* The Asserting Party should use whatever HTTP method specified in
* {@link RelyingPartyRegistration#getSingleLogoutServiceBinding()}.
* {@link RelyingPartyRegistration#getSingleLogoutServiceBindings()}.
* @param logoutUrl the URL that will receive the SAML 2.0 Logout Response
* @return the {@link LogoutResponseConfigurer} for further customizations
* @see Saml2LogoutConfigurer#logoutUrl(String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -104,7 +105,9 @@ private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registrati
.addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION));
spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration));
if (registration.getSingleLogoutServiceLocation() != null) {
spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration));
for (Saml2MessageBinding binding : registration.getSingleLogoutServiceBindings()) {
spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration, binding));
}
}
if (registration.getNameIdFormat() != null) {
spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration));
Expand Down Expand Up @@ -147,11 +150,12 @@ private AssertionConsumerService buildAssertionConsumerService(RelyingPartyRegis
return assertionConsumerService;
}

private SingleLogoutService buildSingleLogoutService(RelyingPartyRegistration registration) {
private SingleLogoutService buildSingleLogoutService(RelyingPartyRegistration registration,
Saml2MessageBinding binding) {
SingleLogoutService singleLogoutService = build(SingleLogoutService.DEFAULT_ELEMENT_NAME);
singleLogoutService.setLocation(registration.getSingleLogoutServiceLocation());
singleLogoutService.setResponseLocation(registration.getSingleLogoutServiceResponseLocation());
singleLogoutService.setBinding(registration.getSingleLogoutServiceBinding().getUrn());
singleLogoutService.setBinding(binding.getUrn());
return singleLogoutService;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
* Represents a configured relying party (aka Service Provider) and asserting party (aka
Expand Down Expand Up @@ -81,7 +82,7 @@ public final class RelyingPartyRegistration {

private final String singleLogoutServiceResponseLocation;

private final Saml2MessageBinding singleLogoutServiceBinding;
private final Collection<Saml2MessageBinding> singleLogoutServiceBindings;

private final String nameIdFormat;

Expand All @@ -93,16 +94,16 @@ public final class RelyingPartyRegistration {

private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding,
String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings,
AssertingPartyDetails assertingPartyDetails, String nameIdFormat,
Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty");
Assert.hasText(entityId, "entityId cannot be empty");
Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty");
Assert.notNull(assertionConsumerServiceBinding, "assertionConsumerServiceBinding cannot be null");
Assert.isTrue(singleLogoutServiceLocation == null || singleLogoutServiceBinding != null,
"singleLogoutServiceBinding cannot be null when singleLogoutServiceLocation is set");
Assert.isTrue(singleLogoutServiceLocation == null || !CollectionUtils.isEmpty(singleLogoutServiceBindings),
"singleLogoutServiceBindings cannot be null or empty when singleLogoutServiceLocation is set");
Assert.notNull(assertingPartyDetails, "assertingPartyDetails cannot be null");
Assert.notNull(decryptionX509Credentials, "decryptionX509Credentials cannot be null");
for (Saml2X509Credential c : decryptionX509Credentials) {
Expand All @@ -121,7 +122,7 @@ private RelyingPartyRegistration(String registrationId, String entityId, String
this.assertionConsumerServiceBinding = assertionConsumerServiceBinding;
this.singleLogoutServiceLocation = singleLogoutServiceLocation;
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
this.singleLogoutServiceBinding = singleLogoutServiceBinding;
this.singleLogoutServiceBindings = Collections.unmodifiableList(new LinkedList<>(singleLogoutServiceBindings));
this.nameIdFormat = nameIdFormat;
this.assertingPartyDetails = assertingPartyDetails;
this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
Expand Down Expand Up @@ -194,7 +195,22 @@ public Saml2MessageBinding getAssertionConsumerServiceBinding() {
* @since 5.6
*/
public Saml2MessageBinding getSingleLogoutServiceBinding() {
return this.singleLogoutServiceBinding;
Assert.state(this.singleLogoutServiceBindings.size() == 1, "Method does not support multiple bindings.");
return this.singleLogoutServiceBindings.iterator().next();
}

/**
* Get the <a href=
* "https://docs.oasis-open.org/security/saml/v2.0/saml-metadata-2.0-os.pdf#page=7">SingleLogoutService
* Binding</a>
* <p>
* Equivalent to the value found in &lt;SingleLogoutService Binding="..."/&gt; in the
* relying party's &lt;SPSSODescriptor&gt;.
* @return the SingleLogoutService Binding
* @since 5.8
*/
public Collection<Saml2MessageBinding> getSingleLogoutServiceBindings() {
return this.singleLogoutServiceBindings;
}

/**
Expand Down Expand Up @@ -308,7 +324,7 @@ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration regi
.assertionConsumerServiceBinding(registration.getAssertionConsumerServiceBinding())
.singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation())
.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
.singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding())
.singleLogoutServiceBindings((c) -> c.addAll(registration.getSingleLogoutServiceBindings()))
.nameIdFormat(registration.getNameIdFormat())
.assertingPartyDetails((assertingParty) -> assertingParty
.entityId(registration.getAssertingPartyDetails().getEntityId())
Expand Down Expand Up @@ -737,7 +753,7 @@ public static final class Builder {

private String singleLogoutServiceResponseLocation;

private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST;
private Collection<Saml2MessageBinding> singleLogoutServiceBindings = new LinkedHashSet<>();

private String nameIdFormat = null;

Expand Down Expand Up @@ -855,7 +871,28 @@ public Builder assertionConsumerServiceBinding(Saml2MessageBinding assertionCons
* @since 5.6
*/
public Builder singleLogoutServiceBinding(Saml2MessageBinding singleLogoutServiceBinding) {
this.singleLogoutServiceBinding = singleLogoutServiceBinding;
return this.singleLogoutServiceBindings((saml2MessageBindings) -> {
saml2MessageBindings.clear();
saml2MessageBindings.add(singleLogoutServiceBinding);
});
}

/**
* Apply this {@link Consumer} to the {@link Collection} of
* {@link Saml2MessageBinding}s for the purposes of modifying the <a href=
* "https://docs.oasis-open.org/security/saml/v2.0/saml-metadata-2.0-os.pdf#page=7">SingleLogoutService
* Binding</a> {@link Collection}.
*
* <p>
* Equivalent to the value found in &lt;SingleLogoutService Binding="..."/&gt; in
* the relying party's &lt;SPSSODescriptor&gt;.
* @param bindingsConsumer - the {@link Consumer} for modifying the
* {@link Collection}
* @return the {@link Builder} for further configuration
* @since 5.8
*/
public Builder singleLogoutServiceBindings(Consumer<Collection<Saml2MessageBinding>> bindingsConsumer) {
bindingsConsumer.accept(this.singleLogoutServiceBindings);
return this;
}

Expand Down Expand Up @@ -925,10 +962,15 @@ public RelyingPartyRegistration build() {
if (this.singleLogoutServiceResponseLocation == null) {
this.singleLogoutServiceResponseLocation = this.singleLogoutServiceLocation;
}

if (this.singleLogoutServiceBindings.isEmpty()) {
this.singleLogoutServiceBindings.add(Saml2MessageBinding.POST);
}

return new RelyingPartyRegistration(this.registrationId, this.entityId,
this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
this.singleLogoutServiceBinding, this.assertingPartyDetailsBuilder.build(), this.nameIdFormat,
this.singleLogoutServiceBindings, this.assertingPartyDetailsBuilder.build(), this.nameIdFormat,
this.decryptionX509Credentials, this.signingX509Credentials);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -134,9 +134,7 @@ Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentic
if (registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation() == null) {
return null;
}
String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
byte[] b = Saml2Utils.samlDecode(serialized);
LogoutRequest logoutRequest = parse(inflateIfRequired(registration, b));
LogoutRequest logoutRequest = parse(extractSamlRequest(request));
LogoutResponse logoutResponse = this.logoutResponseBuilder.buildObject();
logoutResponse.setDestination(registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation());
Issuer issuer = this.issuerBuilder.buildObject();
Expand Down Expand Up @@ -189,8 +187,10 @@ private String getRegistrationId(Authentication authentication) {
return null;
}

private String inflateIfRequired(RelyingPartyRegistration registration, byte[] b) {
if (registration.getSingleLogoutServiceBinding() == Saml2MessageBinding.REDIRECT) {
private String extractSamlRequest(HttpServletRequest request) {
String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
byte[] b = Saml2Utils.samlDecode(serialized);
if (Saml2MessageBindingUtils.isHttpRedirectBinding(request)) {
return Saml2Utils.samlInflate(b);
}
return new String(b, StandardCharsets.UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
if (!isCorrectBinding(request, registration)) {

Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request);
if (!registration.getSingleLogoutServiceBindings().contains(saml2MessageBinding)) {
this.logger.trace("Did not process logout request since used incorrect binding");
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
Expand All @@ -131,8 +133,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
.samlRequest(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
.binding(registration.getSingleLogoutServiceBinding())
.location(registration.getSingleLogoutServiceLocation())
.binding(saml2MessageBinding).location(registration.getSingleLogoutServiceLocation())
.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
request.getParameter(Saml2ParameterNames.SIG_ALG)))
.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
Expand Down Expand Up @@ -177,14 +178,6 @@ private String getRegistrationId(Authentication authentication) {
return null;
}

private boolean isCorrectBinding(HttpServletRequest request, RelyingPartyRegistration registration) {
Saml2MessageBinding requiredBinding = registration.getSingleLogoutServiceBinding();
if (requiredBinding == Saml2MessageBinding.POST) {
return "POST".equals(request.getMethod());
}
return "GET".equals(request.getMethod());
}

private void doRedirect(HttpServletRequest request, HttpServletResponse response,
Saml2LogoutResponse logoutResponse) throws IOException {
String location = logoutResponse.getResponseLocation();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,18 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
if (!isCorrectBinding(request, registration)) {
this.logger.trace("Did not process logout request since used incorrect binding");

Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request);
if (!registration.getSingleLogoutServiceBindings().contains(saml2MessageBinding)) {
this.logger.trace("Did not process logout response since used incorrect binding");
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
}

String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration)
.samlResponse(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
.binding(registration.getSingleLogoutServiceBinding())
.location(registration.getSingleLogoutServiceResponseLocation())
.binding(saml2MessageBinding).location(registration.getSingleLogoutServiceResponseLocation())
.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
request.getParameter(Saml2ParameterNames.SIG_ALG)))
.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
Expand Down Expand Up @@ -167,12 +168,4 @@ public void setLogoutRequestRepository(Saml2LogoutRequestRepository logoutReques
this.logoutRequestRepository = logoutRequestRepository;
}

private boolean isCorrectBinding(HttpServletRequest request, RelyingPartyRegistration registration) {
Saml2MessageBinding requiredBinding = registration.getSingleLogoutServiceBinding();
if (requiredBinding == Saml2MessageBinding.POST) {
return "POST".equals(request.getMethod());
}
return "GET".equals(request.getMethod());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.web.authentication.logout;

import jakarta.servlet.http.HttpServletRequest;

import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;

/**
* Utility methods for working with {@link Saml2MessageBinding}
*
* For internal use only.
*
* @since 5.8
*/
final class Saml2MessageBindingUtils {

private Saml2MessageBindingUtils() {
}

static Saml2MessageBinding resolveBinding(HttpServletRequest request) {
if (isHttpPostBinding(request)) {
return Saml2MessageBinding.POST;
}
else if (isHttpRedirectBinding(request)) {
return Saml2MessageBinding.REDIRECT;
}
throw new Saml2Exception("Unable to determine message binding from request.");
}

private static boolean isSamlRequestResponse(HttpServletRequest request) {
return (request.getParameter(Saml2ParameterNames.SAML_REQUEST) != null
|| request.getParameter(Saml2ParameterNames.SAML_RESPONSE) != null);
}

static boolean isHttpRedirectBinding(HttpServletRequest request) {
return request != null && "GET".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request);
}

static boolean isHttpPostBinding(HttpServletRequest request) {
return request != null && "POST".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request);
}

}