Skip to content

Commit 13e8df9

Browse files
committed
Support SAML 2 attributes
Example code from spring-projects/spring-security#8661
1 parent 9faac54 commit 13e8df9

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
package io.github.danilopiazza.spring.boot.saml.controller;
22

3+
import org.springframework.beans.factory.annotation.Autowired;
34
import org.springframework.http.MediaType;
45
import org.springframework.security.core.annotation.AuthenticationPrincipal;
6+
import org.springframework.security.core.annotation.CurrentSecurityContext;
57
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
8+
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
69
import org.springframework.web.bind.annotation.GetMapping;
710
import org.springframework.web.bind.annotation.RestController;
811

12+
import io.github.danilopiazza.spring.boot.saml.security.Saml2AttributeService;
13+
import io.github.danilopiazza.spring.boot.saml.security.Saml2AuthenticatedPrincipalWithAttributes;
14+
915
@RestController
1016
public class IndexController {
17+
@Autowired
18+
Saml2AttributeService saml2Attributes;
1119

1220
@GetMapping(path = "/", produces = MediaType.APPLICATION_JSON_VALUE)
13-
public Saml2AuthenticatedPrincipal index(@AuthenticationPrincipal Saml2AuthenticatedPrincipal principal) {
14-
return principal;
21+
public Saml2AuthenticatedPrincipalWithAttributes index(
22+
@CurrentSecurityContext(expression = "authentication") Saml2Authentication authentication,
23+
@AuthenticationPrincipal Saml2AuthenticatedPrincipal principal) {
24+
return new Saml2AuthenticatedPrincipalWithAttributes(principal.getName(),
25+
saml2Attributes.getAttributes(authentication));
1526
}
1627
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package io.github.danilopiazza.spring.boot.saml.security;
2+
3+
import java.io.IOException;
4+
import java.io.Reader;
5+
import java.io.StringReader;
6+
import java.util.Collection;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.stream.Collectors;
10+
11+
import org.opensaml.core.xml.XMLObject;
12+
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
13+
import org.opensaml.core.xml.io.UnmarshallingException;
14+
import org.opensaml.core.xml.schema.XSString;
15+
import org.opensaml.core.xml.schema.impl.XSAnyImpl;
16+
import org.opensaml.saml.saml2.core.Response;
17+
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
18+
import org.springframework.stereotype.Service;
19+
import org.w3c.dom.Element;
20+
21+
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
22+
import net.shibboleth.utilities.java.support.xml.BasicParserPool;
23+
import net.shibboleth.utilities.java.support.xml.XMLParserException;
24+
25+
@Service
26+
public class Saml2AttributeService {
27+
public Map<String, List<String>> getAttributes(Saml2Authentication authentication) {
28+
Element element = getDocumentElement(authentication);
29+
Response response = getResponse(element);
30+
return response.getAssertions().stream().flatMap(assertion -> assertion.getAttributeStatements().stream())
31+
.flatMap(attributeStatement -> attributeStatement.getAttributes().stream())
32+
.collect(Collectors.toMap(attribute -> attribute.getName(),
33+
attribute -> getAttributeValues(attribute.getAttributeValues())));
34+
}
35+
36+
private Element getDocumentElement(Saml2Authentication authentication) {
37+
try (Reader reader = new StringReader(authentication.getSaml2Response())) {
38+
BasicParserPool basicParserPool = new BasicParserPool();
39+
basicParserPool.initialize();
40+
return basicParserPool.parse(reader).getDocumentElement();
41+
} catch (ComponentInitializationException | IOException | XMLParserException e) {
42+
throw new IllegalArgumentException(e);
43+
}
44+
}
45+
46+
private Response getResponse(Element element) {
47+
try {
48+
return (Response) XMLObjectProviderRegistrySupport.getUnmarshallerFactory().getUnmarshaller(element)
49+
.unmarshall(element);
50+
} catch (UnmarshallingException e) {
51+
throw new IllegalArgumentException(e);
52+
}
53+
}
54+
55+
private List<String> getAttributeValues(Collection<XMLObject> collection) {
56+
return collection.stream().map(this::getAttributeValue).collect(Collectors.toList());
57+
}
58+
59+
private String getAttributeValue(XMLObject attributeValue) {
60+
return attributeValue == null ? null
61+
: attributeValue instanceof XSString ? getStringAttributeValue((XSString) attributeValue)
62+
: attributeValue instanceof XSAnyImpl ? getAnyAttributeValue((XSAnyImpl) attributeValue)
63+
: attributeValue.toString();
64+
}
65+
66+
private String getStringAttributeValue(XSString attributeValue) {
67+
return attributeValue.getValue();
68+
}
69+
70+
private String getAnyAttributeValue(XSAnyImpl attributeValue) {
71+
return attributeValue.getTextContent();
72+
}
73+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package io.github.danilopiazza.spring.boot.saml.security;
2+
3+
import java.util.List;
4+
import java.util.Map;
5+
6+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
7+
8+
public class Saml2AuthenticatedPrincipalWithAttributes implements Saml2AuthenticatedPrincipal {
9+
private final String name;
10+
private final Map<String, List<String>> attributes;
11+
12+
public Saml2AuthenticatedPrincipalWithAttributes(String name, Map<String, List<String>> attributes) {
13+
this.name = name;
14+
this.attributes = attributes;
15+
}
16+
17+
public String getName() {
18+
return name;
19+
}
20+
21+
public Map<String, List<String>> getAttributes() {
22+
return attributes;
23+
}
24+
25+
@Override
26+
public String toString() {
27+
return "Saml2AuthenticatedPrincipalWithAttributes [attributes=" + attributes + ", name=" + name + "]";
28+
}
29+
}

0 commit comments

Comments
 (0)