Skip to content

Commit 5fa1e8e

Browse files
committed
Allow subclassing OAuth2AuthenticationContext
Closes gh-492
1 parent 5982d22 commit 5fa1e8e

File tree

3 files changed

+118
-18
lines changed

3 files changed

+118
-18
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/authentication/OAuth2AuthenticationContext.java

Lines changed: 109 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
*/
1616
package org.springframework.security.oauth2.core.authentication;
1717

18+
import java.util.Collections;
1819
import java.util.HashMap;
1920
import java.util.Map;
21+
import java.util.function.Consumer;
2022

2123
import org.springframework.lang.Nullable;
2224
import org.springframework.security.core.Authentication;
@@ -25,34 +27,41 @@
2527
import org.springframework.util.CollectionUtils;
2628

2729
/**
28-
* A context that holds an {@link Authentication} and (optionally) additional information
29-
* and is used by an {@link OAuth2AuthenticationValidator} when attempting to validate the {@link Authentication}.
30+
* A context that holds an {@link Authentication} and (optionally) additional information.
3031
*
3132
* @author Joe Grandja
3233
* @since 0.2.0
3334
* @see Context
34-
* @see OAuth2AuthenticationValidator
3535
*/
36-
public final class OAuth2AuthenticationContext implements Context {
36+
public class OAuth2AuthenticationContext implements Context {
3737
private final Map<Object, Object> context;
3838

3939
/**
4040
* Constructs an {@code OAuth2AuthenticationContext} using the provided parameters.
4141
*
4242
* @param authentication the {@code Authentication}
4343
* @param context a {@code Map} of additional context information
44+
* @deprecated Use {@link #with(Authentication)} instead
4445
*/
46+
@Deprecated
4547
public OAuth2AuthenticationContext(Authentication authentication, @Nullable Map<Object, Object> context) {
4648
Assert.notNull(authentication, "authentication cannot be null");
47-
this.context = new HashMap<>();
49+
Map<Object, Object> ctx = new HashMap<>();
4850
if (!CollectionUtils.isEmpty(context)) {
49-
this.context.putAll(context);
51+
ctx.putAll(context);
5052
}
51-
this.context.put(Authentication.class, authentication);
53+
ctx.put(Authentication.class, authentication);
54+
this.context = Collections.unmodifiableMap(ctx);
55+
}
56+
57+
protected OAuth2AuthenticationContext(Map<Object, Object> context) {
58+
Assert.notEmpty(context, "context cannot be empty");
59+
Assert.notNull(context.get(Authentication.class), "authentication cannot be null");
60+
this.context = Collections.unmodifiableMap(new HashMap<>(context));
5261
}
5362

5463
/**
55-
* Returns the {@link Authentication} associated to the authentication context.
64+
* Returns the {@link Authentication} associated to the context.
5665
*
5766
* @param <T> the type of the {@code Authentication}
5867
* @return the {@link Authentication}
@@ -63,14 +72,105 @@ public <T extends Authentication> T getAuthentication() {
6372
}
6473

6574
@SuppressWarnings("unchecked")
75+
@Nullable
6676
@Override
6777
public <V> V get(Object key) {
68-
return (V) this.context.get(key);
78+
return hasKey(key) ? (V) this.context.get(key) : null;
6979
}
7080

7181
@Override
7282
public boolean hasKey(Object key) {
83+
Assert.notNull(key, "key cannot be null");
7384
return this.context.containsKey(key);
7485
}
7586

87+
/**
88+
* Constructs a new {@link Builder} with the provided {@link Authentication}.
89+
*
90+
* @param authentication the {@link Authentication}
91+
* @return the {@link Builder}
92+
*/
93+
public static Builder with(Authentication authentication) {
94+
return new Builder(authentication);
95+
}
96+
97+
/**
98+
* A builder for {@link OAuth2AuthenticationContext}.
99+
*/
100+
public static final class Builder extends AbstractBuilder<OAuth2AuthenticationContext, Builder> {
101+
102+
private Builder(Authentication authentication) {
103+
super(authentication);
104+
}
105+
106+
@Override
107+
public OAuth2AuthenticationContext build() {
108+
return new OAuth2AuthenticationContext(getContext());
109+
}
110+
111+
}
112+
113+
/**
114+
* A builder for subclasses of {@link OAuth2AuthenticationContext}.
115+
*
116+
* @param <T> the type of the authentication context
117+
* @param <B> the type of the builder
118+
*/
119+
protected static abstract class AbstractBuilder<T extends OAuth2AuthenticationContext, B extends AbstractBuilder<T, B>> {
120+
private final Map<Object, Object> context = new HashMap<>();
121+
122+
protected AbstractBuilder(Authentication authentication) {
123+
Assert.notNull(authentication, "authentication cannot be null");
124+
put(Authentication.class, authentication);
125+
}
126+
127+
/**
128+
* Associates an attribute.
129+
*
130+
* @param key the key for the attribute
131+
* @param value the value of the attribute
132+
* @return the {@link AbstractBuilder} for further configuration
133+
*/
134+
public B put(Object key, Object value) {
135+
Assert.notNull(key, "key cannot be null");
136+
Assert.notNull(value, "value cannot be null");
137+
getContext().put(key, value);
138+
return getThis();
139+
}
140+
141+
/**
142+
* A {@code Consumer} of the attributes {@code Map}
143+
* allowing the ability to add, replace, or remove.
144+
*
145+
* @param contextConsumer a {@link Consumer} of the attributes {@code Map}
146+
* @return the {@link AbstractBuilder} for further configuration
147+
*/
148+
public B context(Consumer<Map<Object, Object>> contextConsumer) {
149+
contextConsumer.accept(getContext());
150+
return getThis();
151+
}
152+
153+
@SuppressWarnings("unchecked")
154+
protected <V> V get(Object key) {
155+
return (V) getContext().get(key);
156+
}
157+
158+
protected Map<Object, Object> getContext() {
159+
return this.context;
160+
}
161+
162+
@SuppressWarnings("unchecked")
163+
protected final B getThis() {
164+
return (B) this;
165+
}
166+
167+
/**
168+
* Builds a new {@link OAuth2AuthenticationContext}.
169+
*
170+
* @return the {@link OAuth2AuthenticationContext}
171+
*/
172+
public abstract T build();
173+
174+
}
175+
76176
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,10 @@ private Authentication authenticateAuthorizationRequest(Authentication authentic
156156
authorizationCodeRequestAuthentication, null);
157157
}
158158

159-
Map<Object, Object> context = new HashMap<>();
160-
context.put(RegisteredClient.class, registeredClient);
161-
OAuth2AuthenticationContext authenticationContext = new OAuth2AuthenticationContext(
162-
authorizationCodeRequestAuthentication, context);
159+
OAuth2AuthenticationContext authenticationContext =
160+
OAuth2AuthenticationContext.with(authorizationCodeRequestAuthentication)
161+
.put(RegisteredClient.class, registeredClient)
162+
.build();
163163

164164
OAuth2AuthenticationValidator redirectUriValidator = resolveAuthenticationValidator(OAuth2ParameterNames.REDIRECT_URI);
165165
redirectUriValidator.validate(authenticationContext);

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProvider.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ public Authentication authenticate(Authentication authentication) throws Authent
9898
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
9999
}
100100

101-
Map<Object, Object> context = new HashMap<>();
102-
context.put(OAuth2Token.class, accessTokenAuthentication.getToken());
103-
context.put(OAuth2Authorization.class, authorization);
104-
OAuth2AuthenticationContext authenticationContext = new OAuth2AuthenticationContext(
105-
userInfoAuthentication, context);
101+
OAuth2AuthenticationContext authenticationContext =
102+
OAuth2AuthenticationContext.with(userInfoAuthentication)
103+
.put(OAuth2Token.class, accessTokenAuthentication.getToken())
104+
.put(OAuth2Authorization.class, authorization)
105+
.build();
106106

107107
OidcUserInfo userInfo = this.userInfoMapper.apply(authenticationContext);
108108

0 commit comments

Comments
 (0)