Skip to content

Commit 5b24049

Browse files
author
ryan.cassar
committed
Added getters and setters for the assertion decrypter and principal decrypter
1 parent c2d8939 commit 5b24049

File tree

2 files changed

+137
-29
lines changed

2 files changed

+137
-29
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
* asserting party, IDP, verification certificates.
158158
* </p>
159159
*
160+
* @author Ryan Cassar
160161
* @since 5.2
161162
* @see <a href=
162163
* "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
@@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
211212

212213
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
213214

215+
private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
216+
List<Assertion> assertions = new ArrayList<>();
217+
for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
218+
try {
219+
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
220+
Assertion assertion = decrypter.decrypt(encryptedAssertion);
221+
assertions.add(assertion);
222+
}
223+
catch (DecryptionException ex) {
224+
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
225+
}
226+
}
227+
responseToken.getResponse().getAssertions().addAll(assertions);
228+
};
229+
230+
private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
231+
try {
232+
Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
233+
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
234+
assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
235+
}
236+
catch (DecryptionException ex) {
237+
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
238+
}
239+
};
240+
214241
/**
215242
* Creates an {@link OpenSamlAuthenticationProvider}
216243
*/
@@ -332,6 +359,52 @@ public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
332359
this.responseTimeValidationSkew = responseTimeValidationSkew;
333360
}
334361

362+
/**
363+
* Sets the assertion response custom decrypter.
364+
*
365+
* You can use this method like so:
366+
*
367+
* <pre>
368+
* YourDecrypter decrypter = // ... your custom decrypter
369+
*
370+
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
371+
* provider.setAssertionDecrypter((responseToken) -> {
372+
* Response response = responseToken.getResponse();
373+
* EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
374+
* Assertion assertion = decrypter.decrypt(encrypted);
375+
* response.getAssertions().add(assertion);
376+
* });
377+
* </pre>
378+
* @param assertionDecrypter response token consumer
379+
*/
380+
public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
381+
Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
382+
this.assertionDecrypter = assertionDecrypter;
383+
}
384+
385+
/**
386+
* Sets the principal custom decrypter.
387+
*
388+
* You can use this method like so:
389+
*
390+
* <pre>
391+
* YourDecrypter decrypter = // ... your custom decrypter
392+
*
393+
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
394+
* provider.setAssertionDecrypter((responseToken) -> {
395+
* Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
396+
* EncryptedID encrypted = assertion.getSubject().getEncryptedID();
397+
* NameID name = decrypter.decrypt(encrypted);
398+
* assertion.getSubject().setNameID(name)
399+
* });
400+
* </pre>
401+
* @param principalDecrypter response token consumer
402+
*/
403+
public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
404+
Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
405+
this.principalDecrypter = principalDecrypter;
406+
}
407+
335408
/**
336409
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
337410
* {@link Authentication} token
@@ -429,8 +502,8 @@ private void process(Saml2AuthenticationToken token, Response response) {
429502
boolean responseSigned = response.isSigned();
430503
Saml2ResponseValidatorResult result = validateResponse(token, response);
431504

432-
Decrypter decrypter = this.decrypterConverter.convert(token);
433-
List<Assertion> assertions = decryptAssertions(decrypter, response);
505+
ResponseToken responseToken = new ResponseToken(response, token);
506+
List<Assertion> assertions = decryptAssertions(responseToken);
434507
if (!isSigned(responseSigned, assertions)) {
435508
String description = "Either the response or one of the assertions is unsigned. "
436509
+ "Please either sign the response or all of the assertions.";
@@ -439,7 +512,7 @@ private void process(Saml2AuthenticationToken token, Response response) {
439512
result = result.concat(validateAssertions(token, response));
440513

441514
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
442-
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
515+
NameID nameId = decryptPrincipal(responseToken);
443516
if (nameId == null || nameId.getValue() == null) {
444517
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
445518
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
@@ -511,19 +584,9 @@ private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken t
511584
return Saml2ResponseValidatorResult.failure(errors);
512585
}
513586

514-
private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) {
515-
List<Assertion> assertions = new ArrayList<>();
516-
for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
517-
try {
518-
Assertion assertion = decrypter.decrypt(encryptedAssertion);
519-
assertions.add(assertion);
520-
}
521-
catch (DecryptionException ex) {
522-
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
523-
}
524-
}
525-
response.getAssertions().addAll(assertions);
526-
return response.getAssertions();
587+
private List<Assertion> decryptAssertions(ResponseToken response) {
588+
this.assertionDecrypter.accept(response);
589+
return response.getResponse().getAssertions();
527590
}
528591

529592
private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
@@ -567,21 +630,16 @@ private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
567630
return true;
568631
}
569632

570-
private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) {
633+
private NameID decryptPrincipal(ResponseToken responseToken) {
634+
Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
571635
if (assertion.getSubject() == null) {
572636
return null;
573637
}
574638
if (assertion.getSubject().getEncryptedID() == null) {
575639
return assertion.getSubject().getNameID();
576640
}
577-
try {
578-
NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID());
579-
assertion.getSubject().setNameID(nameId);
580-
return nameId;
581-
}
582-
catch (DecryptionException ex) {
583-
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
584-
}
641+
this.principalDecrypter.accept(responseToken);
642+
return assertion.getSubject().getNameID();
585643
}
586644

587645
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
5757
import org.springframework.security.saml2.credentials.Saml2X509Credential;
5858
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
59+
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken;
5960
import org.springframework.util.StringUtils;
6061

6162
import static org.assertj.core.api.Assertions.assertThat;
@@ -446,17 +447,15 @@ public void setAssertionValidatorWhenNullThenIllegalArgument() {
446447
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
447448
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
448449
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
449-
OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken(
450-
response, token);
450+
ResponseToken responseToken = new ResponseToken(response, token);
451451
Saml2Authentication authentication = OpenSamlAuthenticationProvider
452452
.createDefaultResponseAuthenticationConverter().convert(responseToken);
453453
assertThat(authentication.getName()).isEqualTo("[email protected]");
454454
}
455455

456456
@Test
457457
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
458-
Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter = mock(
459-
Converter.class);
458+
Converter<ResponseToken, Saml2Authentication> authenticationConverter = mock(Converter.class);
460459
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
461460
provider.setResponseAuthenticationConverter(authenticationConverter);
462461
Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
@@ -473,6 +472,57 @@ public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() {
473472
// @formatter:on
474473
}
475474

475+
@Test
476+
public void setAssertionDecrypterWhenNullThenIllegalArgument() {
477+
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
478+
}
479+
480+
@Test
481+
public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
482+
assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
483+
}
484+
485+
@Test
486+
public void setAssertionDecrypterThenChangesAssertion() {
487+
Response response = TestOpenSamlObjects.response();
488+
Assertion assertion = TestOpenSamlObjects.assertion();
489+
assertion.getSubject().getSubjectConfirmations()
490+
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
491+
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
492+
RELYING_PARTY_ENTITY_ID);
493+
response.getAssertions().add(assertion);
494+
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
495+
this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
496+
assertThatExceptionOfType(Saml2AuthenticationException.class)
497+
.isThrownBy(() -> this.provider.authenticate(token))
498+
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
499+
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
500+
}
501+
502+
@Test
503+
public void setPrincipalDecrypterThenChangesAssertion() {
504+
Response response = TestOpenSamlObjects.response();
505+
Assertion assertion = TestOpenSamlObjects.assertion();
506+
assertion.getSubject().getSubjectConfirmations()
507+
.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
508+
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
509+
RELYING_PARTY_ENTITY_ID);
510+
response.getAssertions().add(assertion);
511+
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
512+
this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
513+
this.provider.authenticate(token);
514+
assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
515+
}
516+
517+
private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
518+
return (responseToken) -> {
519+
responseToken.getResponse().getAssertions().clear();
520+
responseToken.getResponse().getAssertions()
521+
.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
522+
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
523+
};
524+
}
525+
476526
private <T extends XMLObject> T build(QName qName) {
477527
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
478528
}

0 commit comments

Comments
 (0)