Skip to content

Commit adf96b4

Browse files
committed
Add OAuth2TokenCustomizer
Closes gh-199
1 parent 3f310ee commit adf96b4

25 files changed

+1115
-268
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

+35
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
3434
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
3535
import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
36+
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
3637
import org.springframework.security.oauth2.jwt.JwtEncoder;
38+
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
3739
import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
3840
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
3941
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -152,23 +154,33 @@ public void init(B builder) {
152154
builder.authenticationProvider(postProcess(clientAuthenticationProvider));
153155

154156
JwtEncoder jwtEncoder = getJwtEncoder(builder);
157+
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = getJwtCustomizer(builder);
155158

156159
OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
157160
new OAuth2AuthorizationCodeAuthenticationProvider(
158161
getAuthorizationService(builder),
159162
jwtEncoder);
163+
if (jwtCustomizer != null) {
164+
authorizationCodeAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
165+
}
160166
builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider));
161167

162168
OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider =
163169
new OAuth2RefreshTokenAuthenticationProvider(
164170
getAuthorizationService(builder),
165171
jwtEncoder);
172+
if (jwtCustomizer != null) {
173+
refreshTokenAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
174+
}
166175
builder.authenticationProvider(postProcess(refreshTokenAuthenticationProvider));
167176

168177
OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider =
169178
new OAuth2ClientCredentialsAuthenticationProvider(
170179
getAuthorizationService(builder),
171180
jwtEncoder);
181+
if (jwtCustomizer != null) {
182+
clientCredentialsAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
183+
}
172184
builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider));
173185

174186
OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider =
@@ -314,6 +326,19 @@ private static <B extends HttpSecurityBuilder<B>> JWKSource<SecurityContext> get
314326
return jwkSource;
315327
}
316328

329+
@SuppressWarnings("unchecked")
330+
private static <B extends HttpSecurityBuilder<B>> OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(B builder) {
331+
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = builder.getSharedObject(OAuth2TokenCustomizer.class);
332+
if (jwtCustomizer == null) {
333+
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class);
334+
jwtCustomizer = getOptionalBean(builder, type);
335+
if (jwtCustomizer != null) {
336+
builder.setSharedObject(OAuth2TokenCustomizer.class, jwtCustomizer);
337+
}
338+
}
339+
return jwtCustomizer;
340+
}
341+
317342
private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettings(B builder) {
318343
ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class);
319344
if (providerSettings == null) {
@@ -353,4 +378,14 @@ private static <B extends HttpSecurityBuilder<B>, T> T getOptionalBean(B builder
353378
}
354379
return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null);
355380
}
381+
382+
@SuppressWarnings("unchecked")
383+
private static <B extends HttpSecurityBuilder<B>, T> T getOptionalBean(B builder, ResolvableType type) {
384+
ApplicationContext context = builder.getSharedObject(ApplicationContext.class);
385+
String[] names = context.getBeanNamesForType(type);
386+
if (names.length > 1) {
387+
throw new NoUniqueBeanDefinitionException(type, names);
388+
}
389+
return names.length == 1 ? (T) context.getBean(names[0]) : null;
390+
}
356391
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.core.context;
17+
18+
import java.util.Map;
19+
20+
import org.springframework.lang.Nullable;
21+
import org.springframework.util.Assert;
22+
23+
/**
24+
* A facility for holding information associated to a specific context.
25+
*
26+
* @author Joe Grandja
27+
* @since 0.1.0
28+
*/
29+
public interface Context {
30+
31+
@Nullable
32+
<V> V get(Object key);
33+
34+
@Nullable
35+
default <V> V get(Class<V> key) {
36+
Assert.notNull(key, "key cannot be null");
37+
V value = get((Object) key);
38+
return key.isInstance(value) ? value : null;
39+
}
40+
41+
boolean hasKey(Object key);
42+
43+
static Context of(Map<Object, Object> context) {
44+
return new DefaultContext(context);
45+
}
46+
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.core.context;
17+
18+
import java.util.Collections;
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
import org.springframework.lang.Nullable;
23+
import org.springframework.util.Assert;
24+
25+
/**
26+
* @author Joe Grandja
27+
* @since 0.1.0
28+
*/
29+
final class DefaultContext implements Context {
30+
private final Map<Object, Object> context;
31+
32+
DefaultContext(Map<Object, Object> context) {
33+
Assert.notNull(context, "context cannot be null");
34+
this.context = Collections.unmodifiableMap(new HashMap<>(context));
35+
}
36+
37+
@SuppressWarnings("unchecked")
38+
@Override
39+
@Nullable
40+
public <V> V get(Object key) {
41+
return hasKey(key) ? (V) this.context.get(key) : null;
42+
}
43+
44+
@Override
45+
public boolean hasKey(Object key) {
46+
Assert.notNull(key, "key cannot be null");
47+
return this.context.containsKey(key);
48+
}
49+
50+
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java

+14-37
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
import java.util.Set;
2424
import java.util.UUID;
2525
import java.util.concurrent.ConcurrentHashMap;
26-
import java.util.concurrent.atomic.AtomicReference;
27-
import java.util.function.BiConsumer;
2826
import java.util.stream.Collectors;
2927

3028
import com.nimbusds.jose.JOSEException;
@@ -46,7 +44,6 @@
4644
import com.nimbusds.jwt.SignedJWT;
4745

4846
import org.springframework.core.convert.converter.Converter;
49-
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
5047
import org.springframework.util.Assert;
5148
import org.springframework.util.CollectionUtils;
5249
import org.springframework.util.StringUtils;
@@ -88,9 +85,6 @@ public final class NimbusJwsEncoder implements JwtEncoder {
8885

8986
private final JWKSource<SecurityContext> jwkSource;
9087

91-
private BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = (headers, claims) -> {
92-
};
93-
9488
/**
9589
* Constructs a {@code NimbusJwsEncoder} using the provided parameters.
9690
* @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
@@ -100,32 +94,12 @@ public NimbusJwsEncoder(JWKSource<SecurityContext> jwkSource) {
10094
this.jwkSource = jwkSource;
10195
}
10296

103-
/**
104-
* Sets the {@link Jwt} customizer to be provided the {@link JoseHeader.Builder} and
105-
* {@link JwtClaimsSet.Builder} allowing for further customizations.
106-
* @param jwtCustomizer the {@link Jwt} customizer to be provided the
107-
* {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
108-
*/
109-
public void setJwtCustomizer(BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer) {
110-
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
111-
this.jwtCustomizer = jwtCustomizer;
112-
}
113-
11497
@Override
11598
public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
11699
Assert.notNull(headers, "headers cannot be null");
117100
Assert.notNull(claims, "claims cannot be null");
118101

119-
// @formatter:off
120-
JoseHeader.Builder headersBuilder = JoseHeader.from(headers)
121-
.type(JOSEObjectType.JWT.getType());
122-
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims)
123-
.id(UUID.randomUUID().toString());
124-
// @formatter:on
125-
126-
this.jwtCustomizer.accept(headersBuilder, claimsBuilder);
127-
128-
JWK jwk = selectJwk(headersBuilder);
102+
JWK jwk = selectJwk(headers);
129103
if (jwk == null) {
130104
throw new JwtEncodingException(
131105
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
@@ -135,8 +109,15 @@ else if (!StringUtils.hasText(jwk.getKeyID())) {
135109
"The \"kid\" (key ID) from the selected JWK cannot be empty"));
136110
}
137111

138-
headers = headersBuilder.keyId(jwk.getKeyID()).build();
139-
claims = claimsBuilder.build();
112+
// @formatter:off
113+
headers = JoseHeader.from(headers)
114+
.type(JOSEObjectType.JWT.getType())
115+
.keyId(jwk.getKeyID())
116+
.build();
117+
claims = JwtClaimsSet.from(claims)
118+
.id(UUID.randomUUID().toString())
119+
.build();
120+
// @formatter:on
140121

141122
JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers);
142123
JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims);
@@ -164,13 +145,9 @@ else if (!StringUtils.hasText(jwk.getKeyID())) {
164145
return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims());
165146
}
166147

167-
private JWK selectJwk(JoseHeader.Builder headersBuilder) {
168-
final AtomicReference<JWSAlgorithm> jwsAlgorithm = new AtomicReference<>();
169-
headersBuilder.headers((h) -> {
170-
JwsAlgorithm jwsAlg = (JwsAlgorithm) h.get(JoseHeaderNames.ALG);
171-
jwsAlgorithm.set(JWSAlgorithm.parse(jwsAlg.getName()));
172-
});
173-
JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm.get());
148+
private JWK selectJwk(JoseHeader headers) {
149+
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getJwsAlgorithm().getName());
150+
JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm);
174151
JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
175152

176153
List<JWK> jwks;
@@ -184,7 +161,7 @@ private JWK selectJwk(JoseHeader.Builder headersBuilder) {
184161

185162
if (jwks.size() > 1) {
186163
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
187-
"Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.get().getName() + "'"));
164+
"Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.getName() + "'"));
188165
}
189166

190167
return !jwks.isEmpty() ? jwks.get(0) : null;

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 the original author or authors.
2+
* Copyright 2020-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -56,4 +56,9 @@ public interface OAuth2AuthorizationAttributeNames {
5656
*/
5757
String ACCESS_TOKEN_ATTRIBUTES = OAuth2Authorization.class.getName().concat(".ACCESS_TOKEN_ATTRIBUTES");
5858

59+
/**
60+
* The name of the attribute used for the resource owner {@code Principal}.
61+
*/
62+
String PRINCIPAL = OAuth2Authorization.class.getName().concat(".PRINCIPAL");
63+
5964
}

0 commit comments

Comments
 (0)