diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index d83830e5939..37ba024cb98 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -147,7 +147,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv private Converter responseValidator = createDefaultResponseValidator(); - private final Converter assertionSignatureValidator = createDefaultAssertionSignatureValidator(); + private Converter assertionSignatureValidator = createDefaultAssertionSignatureValidator(); private Consumer assertionElementsDecrypter = createDefaultAssertionElementsDecrypter(); @@ -235,6 +235,29 @@ public void setResponseValidator(Converter + * OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider(); + * provider.setAssertionSignatureValidator(assertionToken -> { + * Saml2ResponseValidatorResult result = createDefaultAssertionSignatureValidator() + * .convert(assertionToken) + * return result.concat(myCustomValidator.convert(assertionToken)); + * }); + * + * @param assertionSignatureValidator the {@link Converter} to use + * @since 5.6 + */ + public void setAssertionSignatureValidator( + Converter assertionSignatureValidator) { + Assert.notNull(assertionSignatureValidator, "assertionSignatureValidator cannot be null"); + this.assertionSignatureValidator = assertionSignatureValidator; + } + /** * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML * 2.0 Response. @@ -386,6 +409,32 @@ public static Converter createDefau }; } + /** + * Construct a default strategy for validating each SAML 2.0 Assertion Signature + * @return the default assertion signature validator strategy + * @since 5.6 + */ + public static Converter createDefaultAssertionSignatureValidator() { + return createDefaultAssertionSignatureValidator((assertionToken) -> new ValidationContext( + Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false))); + } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion Signature + * @param contextConverter the conversion strategy to use to generate a + * {@link ValidationContext} for each assertion being validated + * @return the default assertion signature validator strategy + * @since 5.6 + */ + public static Converter createDefaultAssertionSignatureValidator( + Converter contextConverter) { + return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { + RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); + SignatureTrustEngine engine = OpenSamlVerificationUtils.trustEngine(registration); + return SAML20AssertionValidators.createSignatureValidator(engine); + }, contextConverter); + } + /** * Construct a default strategy for validating each SAML 2.0 Assertion and associated * {@link Authentication} token @@ -560,15 +609,6 @@ private static String getStatusCode(Response response) { return response.getStatus().getStatusCode().getValue(); } - private Converter createDefaultAssertionSignatureValidator() { - return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { - RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); - SignatureTrustEngine engine = OpenSamlVerificationUtils.trustEngine(registration); - return SAML20AssertionValidators.createSignatureValidator(engine); - }, (assertionToken) -> new ValidationContext( - Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false))); - } - private Consumer createDefaultAssertionElementsDecrypter() { return (assertionToken) -> { Assertion assertion = assertionToken.getAssertion(); diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index 46724540fd7..9facd310945 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -613,6 +613,34 @@ public void authenticateWhenCustomResponseValidatorThenUses() { verify(validator).convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class)); } + @Test + public void setAssertionSignatureValidatorWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionSignatureValidator(null)); + } + + @Test + public void authenticateWhenCustomAssertionSignatureValidatorThenUses() { + Converter validator = mock( + Converter.class); + OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider(); + // @formatter:off + provider.setAssertionSignatureValidator((responseToken) -> OpenSaml4AuthenticationProvider.createDefaultAssertionSignatureValidator() + .convert(responseToken) + .concat(validator.convert(responseToken)) + ); + // @formatter:on + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(assertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, verifying(registration())); + given(validator.convert(any(OpenSaml4AuthenticationProvider.AssertionToken.class))) + .willReturn(Saml2ResponseValidatorResult.success()); + provider.authenticate(token); + verify(validator).convert(any(OpenSaml4AuthenticationProvider.AssertionToken.class)); + } + private T build(QName qName) { return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); }