principalDecrypter = (responseToken) -> {
+ try {
+ Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
+ Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
+ assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
+ }
+ catch (DecryptionException ex) {
+ throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
+ }
+ };
+
/**
* Creates an {@link OpenSamlAuthenticationProvider}
*/
@@ -332,6 +359,52 @@ public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
this.responseTimeValidationSkew = responseTimeValidationSkew;
}
+ /**
+ * Sets the assertion response custom decrypter.
+ *
+ * You can use this method like so:
+ *
+ *
+ * YourDecrypter decrypter = // ... your custom decrypter
+ *
+ * OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+ * provider.setAssertionDecrypter((responseToken) -> {
+ * Response response = responseToken.getResponse();
+ * EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
+ * Assertion assertion = decrypter.decrypt(encrypted);
+ * response.getAssertions().add(assertion);
+ * });
+ *
+ * @param assertionDecrypter response token consumer
+ */
+ public void setAssertionDecrypter(Consumer assertionDecrypter) {
+ Assert.notNull(assertionDecrypter, "Consumer required");
+ this.assertionDecrypter = assertionDecrypter;
+ }
+
+ /**
+ * Sets the principal custom decrypter.
+ *
+ * You can use this method like so:
+ *
+ *
+ * YourDecrypter decrypter = // ... your custom decrypter
+ *
+ * OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+ * provider.setAssertionDecrypter((responseToken) -> {
+ * Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
+ * EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+ * NameID name = decrypter.decrypt(encrypted);
+ * assertion.getSubject().setNameID(name)
+ * });
+ *
+ * @param principalDecrypter response token consumer
+ */
+ public void setPrincipalDecrypter(Consumer principalDecrypter) {
+ Assert.notNull(principalDecrypter, "Consumer required");
+ this.principalDecrypter = principalDecrypter;
+ }
+
/**
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
* {@link Authentication} token
@@ -429,8 +502,8 @@ private void process(Saml2AuthenticationToken token, Response response) {
boolean responseSigned = response.isSigned();
Saml2ResponseValidatorResult result = validateResponse(token, response);
- Decrypter decrypter = this.decrypterConverter.convert(token);
- List assertions = decryptAssertions(decrypter, response);
+ ResponseToken responseToken = new ResponseToken(response, token);
+ List assertions = decryptAssertions(responseToken);
if (!isSigned(responseSigned, assertions)) {
String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions.";
@@ -439,7 +512,7 @@ private void process(Saml2AuthenticationToken token, Response response) {
result = result.concat(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
- NameID nameId = decryptPrincipal(decrypter, firstAssertion);
+ NameID nameId = decryptPrincipal(responseToken);
if (nameId == null || nameId.getValue() == null) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
@@ -511,19 +584,9 @@ private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken t
return Saml2ResponseValidatorResult.failure(errors);
}
- private List decryptAssertions(Decrypter decrypter, Response response) {
- List assertions = new ArrayList<>();
- for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
- try {
- Assertion assertion = decrypter.decrypt(encryptedAssertion);
- assertions.add(assertion);
- }
- catch (DecryptionException ex) {
- throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
- }
- }
- response.getAssertions().addAll(assertions);
- return response.getAssertions();
+ private List decryptAssertions(ResponseToken response) {
+ this.assertionDecrypter.accept(response);
+ return response.getResponse().getAssertions();
}
private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
@@ -567,21 +630,16 @@ private boolean isSigned(boolean responseSigned, List assertions) {
return true;
}
- private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) {
+ private NameID decryptPrincipal(ResponseToken responseToken) {
+ Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
if (assertion.getSubject() == null) {
return null;
}
if (assertion.getSubject().getEncryptedID() == null) {
return assertion.getSubject().getNameID();
}
- try {
- NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID());
- assertion.getSubject().setNameID(nameId);
- return nameId;
- }
- catch (DecryptionException ex) {
- throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
- }
+ this.principalDecrypter.accept(responseToken);
+ return assertion.getSubject().getNameID();
}
private static Map> getAssertionAttributes(Assertion assertion) {
diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java
index e0ebdb6e252..8cc69c17b5b 100644
--- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java
+++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java
@@ -56,6 +56,7 @@
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken;
import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat;
@@ -446,8 +447,7 @@ public void setAssertionValidatorWhenNullThenIllegalArgument() {
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
- OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken(
- response, token);
+ ResponseToken responseToken = new ResponseToken(response, token);
Saml2Authentication authentication = OpenSamlAuthenticationProvider
.createDefaultResponseAuthenticationConverter().convert(responseToken);
assertThat(authentication.getName()).isEqualTo("test@saml.user");
@@ -455,8 +455,7 @@ public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts
@Test
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
- Converter authenticationConverter = mock(
- Converter.class);
+ Converter authenticationConverter = mock(Converter.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setResponseAuthenticationConverter(authenticationConverter);
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
@@ -473,6 +472,57 @@ public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() {
// @formatter:on
}
+ @Test
+ public void setAssertionDecrypterWhenNullThenIllegalArgument() {
+ assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
+ }
+
+ @Test
+ public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
+ assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
+ }
+
+ @Test
+ public void setAssertionDecrypterThenChangesAssertion() {
+ Response response = TestOpenSamlObjects.response();
+ Assertion assertion = TestOpenSamlObjects.assertion();
+ assertion.getSubject().getSubjectConfirmations()
+ .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
+ TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
+ RELYING_PARTY_ENTITY_ID);
+ response.getAssertions().add(assertion);
+ Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
+ this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
+ .isThrownBy(() -> this.provider.authenticate(token))
+ .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
+ assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
+ }
+
+ @Test
+ public void setPrincipalDecrypterThenChangesAssertion() {
+ Response response = TestOpenSamlObjects.response();
+ Assertion assertion = TestOpenSamlObjects.assertion();
+ assertion.getSubject().getSubjectConfirmations()
+ .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
+ TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
+ RELYING_PARTY_ENTITY_ID);
+ response.getAssertions().add(assertion);
+ Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
+ this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
+ this.provider.authenticate(token);
+ assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
+ }
+
+ private Consumer mockAssertionAndPrincipalDecrypter() {
+ return (responseToken) -> {
+ responseToken.getResponse().getAssertions().clear();
+ responseToken.getResponse().getAssertions()
+ .add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
+ TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
+ };
+ }
+
private T build(QName qName) {
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
}