Skip to content

Add HSM Support for Decrypting Assertions #9055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
* asserting party, IDP, verification certificates.
* </p>
*
* @author Ryan Cassar
* @since 5.2
* @see <a href=
* "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
Expand Down Expand Up @@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi

private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();

private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
List<Assertion> assertions = new ArrayList<>();
for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
try {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Assertion assertion = decrypter.decrypt(encryptedAssertion);
assertions.add(assertion);
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
}
responseToken.getResponse().getAssertions().addAll(assertions);
};

private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
try {
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};

/**
* Creates an {@link OpenSamlAuthenticationProvider}
*/
Expand Down Expand Up @@ -332,6 +359,52 @@ public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
this.responseTimeValidationSkew = responseTimeValidationSkew;
}

/**
* Sets the assertion response custom decrypter.
*
* You can use this method like so:
*
* <pre>
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionDecrypter((responseToken) -> {
* Response response = responseToken.getResponse();
* EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
* Assertion assertion = decrypter.decrypt(encrypted);
* response.getAssertions().add(assertion);
* });
* </pre>
* @param assertionDecrypter response token consumer
*/
public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
this.assertionDecrypter = assertionDecrypter;
}

/**
* Sets the principal custom decrypter.
*
* You can use this method like so:
*
* <pre>
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* provider.setAssertionDecrypter((responseToken) -> {
* Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
* NameID name = decrypter.decrypt(encrypted);
* assertion.getSubject().setNameID(name)
* });
* </pre>
* @param principalDecrypter response token consumer
*/
public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
this.principalDecrypter = principalDecrypter;
}

/**
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
* {@link Authentication} token
Expand Down Expand Up @@ -429,8 +502,8 @@ private void process(Saml2AuthenticationToken token, Response response) {
boolean responseSigned = response.isSigned();
Saml2ResponseValidatorResult result = validateResponse(token, response);

Decrypter decrypter = this.decrypterConverter.convert(token);
List<Assertion> assertions = decryptAssertions(decrypter, response);
ResponseToken responseToken = new ResponseToken(response, token);
List<Assertion> assertions = decryptAssertions(responseToken);
if (!isSigned(responseSigned, assertions)) {
String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions.";
Expand All @@ -439,7 +512,7 @@ private void process(Saml2AuthenticationToken token, Response response) {
result = result.concat(validateAssertions(token, response));

Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
NameID nameId = decryptPrincipal(responseToken);
if (nameId == null || nameId.getValue() == null) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
Expand Down Expand Up @@ -511,19 +584,9 @@ private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken t
return Saml2ResponseValidatorResult.failure(errors);
}

private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) {
List<Assertion> assertions = new ArrayList<>();
for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
try {
Assertion assertion = decrypter.decrypt(encryptedAssertion);
assertions.add(assertion);
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
}
response.getAssertions().addAll(assertions);
return response.getAssertions();
private List<Assertion> decryptAssertions(ResponseToken response) {
this.assertionDecrypter.accept(response);
return response.getResponse().getAssertions();
}

private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
Expand Down Expand Up @@ -567,21 +630,16 @@ private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
return true;
}

private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) {
private NameID decryptPrincipal(ResponseToken responseToken) {
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
if (assertion.getSubject() == null) {
return null;
}
if (assertion.getSubject().getEncryptedID() == null) {
return assertion.getSubject().getNameID();
}
try {
NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID());
assertion.getSubject().setNameID(nameId);
return nameId;
}
catch (DecryptionException ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
this.principalDecrypter.accept(responseToken);
return assertion.getSubject().getNameID();
}

private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken;
import org.springframework.util.StringUtils;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -446,17 +447,15 @@ public void setAssertionValidatorWhenNullThenIllegalArgument() {
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken(
response, token);
ResponseToken responseToken = new ResponseToken(response, token);
Saml2Authentication authentication = OpenSamlAuthenticationProvider
.createDefaultResponseAuthenticationConverter().convert(responseToken);
assertThat(authentication.getName()).isEqualTo("[email protected]");
}

@Test
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter = mock(
Converter.class);
Converter<ResponseToken, Saml2Authentication> authenticationConverter = mock(Converter.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setResponseAuthenticationConverter(authenticationConverter);
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
Expand All @@ -473,6 +472,57 @@ public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() {
// @formatter:on
}

@Test
public void setAssertionDecrypterWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
}

@Test
public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
}

@Test
public void setAssertionDecrypterThenChangesAssertion() {
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations()
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
}

@Test
public void setPrincipalDecrypterThenChangesAssertion() {
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations()
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
this.provider.authenticate(token);
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
}

private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
return (responseToken) -> {
responseToken.getResponse().getAssertions().clear();
responseToken.getResponse().getAssertions()
.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
};
}

private <T extends XMLObject> T build(QName qName) {
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
}
Expand Down