Skip to content

Commit 79f1cf5

Browse files
committed
Allow customizing Jwt claims and headers
Closes gh-173
1 parent f97b8b2 commit 79f1cf5

File tree

4 files changed

+127
-44
lines changed

4 files changed

+127
-44
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
2727
import org.springframework.security.crypto.key.CryptoKeySource;
2828
import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder;
29+
import org.springframework.security.oauth2.jwt.JwtEncoder;
2930
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
3031
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3132
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
@@ -166,7 +167,7 @@ public void init(B builder) {
166167
getAuthorizationService(builder));
167168
builder.authenticationProvider(postProcess(clientAuthenticationProvider));
168169

169-
NimbusJwsEncoder jwtEncoder = new NimbusJwsEncoder(getKeySource(builder));
170+
JwtEncoder jwtEncoder = getJwtEncoder(builder);
170171

171172
OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
172173
new OAuth2AuthorizationCodeAuthenticationProvider(
@@ -253,23 +254,29 @@ public void configure(B builder) {
253254
builder.addFilterAfter(postProcess(tokenRevocationEndpointFilter), OAuth2TokenEndpointFilter.class);
254255
}
255256

257+
private static void validateProviderSettings(ProviderSettings providerSettings) {
258+
if (providerSettings.issuer() != null) {
259+
try {
260+
new URI(providerSettings.issuer()).toURL();
261+
} catch (Exception ex) {
262+
throw new IllegalArgumentException("issuer must be a valid URL", ex);
263+
}
264+
}
265+
}
266+
256267
private static <B extends HttpSecurityBuilder<B>> RegisteredClientRepository getRegisteredClientRepository(B builder) {
257268
RegisteredClientRepository registeredClientRepository = builder.getSharedObject(RegisteredClientRepository.class);
258269
if (registeredClientRepository == null) {
259-
registeredClientRepository = getRegisteredClientRepositoryBean(builder);
270+
registeredClientRepository = getBean(builder, RegisteredClientRepository.class);
260271
builder.setSharedObject(RegisteredClientRepository.class, registeredClientRepository);
261272
}
262273
return registeredClientRepository;
263274
}
264275

265-
private static <B extends HttpSecurityBuilder<B>> RegisteredClientRepository getRegisteredClientRepositoryBean(B builder) {
266-
return builder.getSharedObject(ApplicationContext.class).getBean(RegisteredClientRepository.class);
267-
}
268-
269276
private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizationService getAuthorizationService(B builder) {
270277
OAuth2AuthorizationService authorizationService = builder.getSharedObject(OAuth2AuthorizationService.class);
271278
if (authorizationService == null) {
272-
authorizationService = getAuthorizationServiceBean(builder);
279+
authorizationService = getOptionalBean(builder, OAuth2AuthorizationService.class);
273280
if (authorizationService == null) {
274281
authorizationService = new InMemoryOAuth2AuthorizationService();
275282
}
@@ -278,34 +285,28 @@ private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizationService get
278285
return authorizationService;
279286
}
280287

281-
private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizationService getAuthorizationServiceBean(B builder) {
282-
Map<String, OAuth2AuthorizationService> authorizationServiceMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
283-
builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizationService.class);
284-
if (authorizationServiceMap.size() > 1) {
285-
throw new NoUniqueBeanDefinitionException(OAuth2AuthorizationService.class, authorizationServiceMap.size(),
286-
"Expected single matching bean of type '" + OAuth2AuthorizationService.class.getName() + "' but found " +
287-
authorizationServiceMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizationServiceMap.keySet()));
288+
private static <B extends HttpSecurityBuilder<B>> JwtEncoder getJwtEncoder(B builder) {
289+
JwtEncoder jwtEncoder = getOptionalBean(builder, JwtEncoder.class);
290+
if (jwtEncoder == null) {
291+
CryptoKeySource keySource = getKeySource(builder);
292+
jwtEncoder = new NimbusJwsEncoder(keySource);
288293
}
289-
return (!authorizationServiceMap.isEmpty() ? authorizationServiceMap.values().iterator().next() : null);
294+
return jwtEncoder;
290295
}
291296

292297
private static <B extends HttpSecurityBuilder<B>> CryptoKeySource getKeySource(B builder) {
293298
CryptoKeySource keySource = builder.getSharedObject(CryptoKeySource.class);
294299
if (keySource == null) {
295-
keySource = getKeySourceBean(builder);
300+
keySource = getBean(builder, CryptoKeySource.class);
296301
builder.setSharedObject(CryptoKeySource.class, keySource);
297302
}
298303
return keySource;
299304
}
300305

301-
private static <B extends HttpSecurityBuilder<B>> CryptoKeySource getKeySourceBean(B builder) {
302-
return builder.getSharedObject(ApplicationContext.class).getBean(CryptoKeySource.class);
303-
}
304-
305306
private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettings(B builder) {
306307
ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class);
307308
if (providerSettings == null) {
308-
providerSettings = getProviderSettingsBean(builder);
309+
providerSettings = getOptionalBean(builder, ProviderSettings.class);
309310
if (providerSettings == null) {
310311
providerSettings = new ProviderSettings();
311312
}
@@ -314,24 +315,18 @@ private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSe
314315
return providerSettings;
315316
}
316317

317-
private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettingsBean(B builder) {
318-
Map<String, ProviderSettings> providerSettingsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
319-
builder.getSharedObject(ApplicationContext.class), ProviderSettings.class);
320-
if (providerSettingsMap.size() > 1) {
321-
throw new NoUniqueBeanDefinitionException(ProviderSettings.class, providerSettingsMap.size(),
322-
"Expected single matching bean of type '" + ProviderSettings.class.getName() + "' but found " +
323-
providerSettingsMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(providerSettingsMap.keySet()));
324-
}
325-
return (!providerSettingsMap.isEmpty() ? providerSettingsMap.values().iterator().next() : null);
318+
private static <B extends HttpSecurityBuilder<B>, T> T getBean(B builder, Class<T> type) {
319+
return builder.getSharedObject(ApplicationContext.class).getBean(type);
326320
}
327321

328-
private void validateProviderSettings(ProviderSettings providerSettings) {
329-
if (providerSettings.issuer() != null) {
330-
try {
331-
new URI(providerSettings.issuer()).toURL();
332-
} catch (Exception ex) {
333-
throw new IllegalArgumentException("issuer must be a valid URL", ex);
334-
}
322+
private static <B extends HttpSecurityBuilder<B>, T> T getOptionalBean(B builder, Class<T> type) {
323+
Map<String, T> beansMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
324+
builder.getSharedObject(ApplicationContext.class), type);
325+
if (beansMap.size() > 1) {
326+
throw new NoUniqueBeanDefinitionException(type, beansMap.size(),
327+
"Expected single matching bean of type '" + type.getName() + "' but found " +
328+
beansMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(beansMap.keySet()));
335329
}
330+
return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null);
336331
}
337332
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import java.util.Map;
5454
import java.util.Set;
5555
import java.util.UUID;
56+
import java.util.function.BiConsumer;
5657
import java.util.stream.Collectors;
5758

5859
/**
@@ -94,6 +95,7 @@ public final class NimbusJwsEncoder implements JwtEncoder {
9495
private static final Converter<JoseHeader, JWSHeader> jwsHeaderConverter = new JwsHeaderConverter();
9596
private static final Converter<JwtClaimsSet, JWTClaimsSet> jwtClaimsSetConverter = new JwtClaimsSetConverter();
9697
private final CryptoKeySource keySource;
98+
private BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = (headers, claims) -> {};
9799

98100
/**
99101
* Constructs a {@code NimbusJwsEncoder} using the provided parameters.
@@ -105,6 +107,19 @@ public NimbusJwsEncoder(CryptoKeySource keySource) {
105107
this.keySource = keySource;
106108
}
107109

110+
/**
111+
* Sets the {@link Jwt} customizer to be provided the
112+
* {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
113+
* allowing for further customizations.
114+
*
115+
* @param jwtCustomizer the {@link Jwt} customizer to be provided the
116+
* {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
117+
*/
118+
public void setJwtCustomizer(BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer) {
119+
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
120+
this.jwtCustomizer = jwtCustomizer;
121+
}
122+
108123
@Override
109124
public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
110125
Assert.notNull(headers, "headers cannot be null");
@@ -136,15 +151,18 @@ public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingExc
136151
}
137152
}
138153

139-
headers = JoseHeader.from(headers)
154+
JoseHeader.Builder headersBuilder = JoseHeader.from(headers)
140155
.type(JOSEObjectType.JWT.getType())
141-
.keyId(cryptoKey.getId())
142-
.build();
143-
JWSHeader jwsHeader = jwsHeaderConverter.convert(headers);
156+
.keyId(cryptoKey.getId());
157+
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims)
158+
.id(UUID.randomUUID().toString());
159+
160+
this.jwtCustomizer.accept(headersBuilder, claimsBuilder);
144161

145-
claims = JwtClaimsSet.from(claims)
146-
.id(UUID.randomUUID().toString())
147-
.build();
162+
headers = headersBuilder.build();
163+
claims = claimsBuilder.build();
164+
165+
JWSHeader jwsHeader = jwsHeaderConverter.convert(headers);
148166
JWTClaimsSet jwtClaimsSet = jwtClaimsSetConverter.convert(claims);
149167

150168
SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet);

oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
3434
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
3535
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
36+
import org.springframework.security.oauth2.jose.JoseHeader;
37+
import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder;
38+
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
39+
import org.springframework.security.oauth2.jwt.JwtEncoder;
3640
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
3741
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3842
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
@@ -53,6 +57,7 @@
5357
import java.net.URLEncoder;
5458
import java.nio.charset.StandardCharsets;
5559
import java.util.Base64;
60+
import java.util.function.BiConsumer;
5661

5762
import static org.assertj.core.api.Assertions.assertThat;
5863
import static org.hamcrest.CoreMatchers.containsString;
@@ -86,6 +91,8 @@ public class OAuth2AuthorizationCodeGrantTests {
8691
private static RegisteredClientRepository registeredClientRepository;
8792
private static OAuth2AuthorizationService authorizationService;
8893
private static CryptoKeySource keySource;
94+
private static NimbusJwsEncoder jwtEncoder;
95+
private static BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer;
8996

9097
@Rule
9198
public final SpringTestRule spring = new SpringTestRule();
@@ -98,6 +105,9 @@ public static void init() {
98105
registeredClientRepository = mock(RegisteredClientRepository.class);
99106
authorizationService = mock(OAuth2AuthorizationService.class);
100107
keySource = new StaticKeyGeneratingCryptoKeySource();
108+
jwtEncoder = new NimbusJwsEncoder(keySource);
109+
jwtCustomizer = mock(BiConsumer.class);
110+
jwtEncoder.setJwtCustomizer(jwtCustomizer);
101111
}
102112

103113
@Before
@@ -223,6 +233,28 @@ public void requestWhenPublicClientWithPkceThenReturnAccessTokenResponse() throw
223233
verify(authorizationService, times(2)).save(any());
224234
}
225235

236+
@Test
237+
public void requestWhenCustomJwtEncoderThenUsed() throws Exception {
238+
this.spring.register(AuthorizationServerConfigurationWithJwtEncoder.class).autowire();
239+
240+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
241+
when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
242+
.thenReturn(registeredClient);
243+
244+
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
245+
when(authorizationService.findByToken(
246+
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
247+
eq(TokenType.AUTHORIZATION_CODE)))
248+
.thenReturn(authorization);
249+
250+
this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
251+
.params(getTokenRequestParameters(registeredClient, authorization))
252+
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
253+
registeredClient.getClientId(), registeredClient.getClientSecret())));
254+
255+
verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
256+
}
257+
226258
private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
227259
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
228260
parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
@@ -270,4 +302,14 @@ CryptoKeySource keySource() {
270302
return keySource;
271303
}
272304
}
305+
306+
@EnableWebSecurity
307+
@Import(OAuth2AuthorizationServerConfiguration.class)
308+
static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration {
309+
310+
@Bean
311+
JwtEncoder jwtEncoder() {
312+
return jwtEncoder;
313+
}
314+
}
273315
}

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@
3232
import java.security.interfaces.RSAPublicKey;
3333
import java.util.Collections;
3434
import java.util.LinkedHashSet;
35+
import java.util.function.BiConsumer;
3536
import java.util.stream.Collectors;
3637
import java.util.stream.Stream;
3738

3839
import static org.assertj.core.api.Assertions.assertThatThrownBy;
40+
import static org.mockito.ArgumentMatchers.any;
3941
import static org.mockito.Mockito.mock;
42+
import static org.mockito.Mockito.verify;
4043
import static org.mockito.Mockito.when;
4144

4245
/**
@@ -61,6 +64,13 @@ public void constructorWhenKeySourceNullThenThrowIllegalArgumentException() {
6164
.hasMessage("keySource cannot be null");
6265
}
6366

67+
@Test
68+
public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
69+
assertThatThrownBy(() -> this.jwtEncoder.setJwtCustomizer(null))
70+
.isInstanceOf(IllegalArgumentException.class)
71+
.hasMessage("jwtCustomizer cannot be null");
72+
}
73+
6474
@Test
6575
public void encodeWhenHeadersNullThenThrowIllegalArgumentException() {
6676
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
@@ -128,6 +138,24 @@ public void encodeWhenSuccessThenDecodes() {
128138
jwtDecoder.decode(jws.getTokenValue());
129139
}
130140

141+
@Test
142+
public void encodeWhenCustomizerSetThenCalled() {
143+
AsymmetricKey rsaKey = TestCryptoKeys.rsaKey().build();
144+
when(this.keySource.getKeys()).thenReturn(Collections.singleton(rsaKey));
145+
146+
JoseHeader joseHeader = TestJoseHeaders.joseHeader()
147+
.headers(headers -> headers.remove(JoseHeaderNames.CRIT))
148+
.build();
149+
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
150+
151+
BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = mock(BiConsumer.class);
152+
this.jwtEncoder.setJwtCustomizer(jwtCustomizer);
153+
154+
this.jwtEncoder.encode(joseHeader, jwtClaimsSet);
155+
156+
verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
157+
}
158+
131159
@Test
132160
public void encodeWhenMultipleActiveKeysThenUseFirst() {
133161
AsymmetricKey rsaKey1 = TestCryptoKeys.rsaKey().build();

0 commit comments

Comments
 (0)