diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 2b51e378649..4bbdac1454a 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -15,37 +15,22 @@ */ package org.springframework.security.oauth2.jwt; -import java.security.interfaces.RSAPublicKey; -import java.time.Instant; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -import com.nimbusds.jose.JOSEException; -import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.jwk.JWK; -import com.nimbusds.jose.jwk.JWKSelector; -import com.nimbusds.jose.jwk.JWKSet; -import com.nimbusds.jose.jwk.RSAKey; -import com.nimbusds.jose.jwk.source.ImmutableJWKSet; -import com.nimbusds.jose.jwk.source.JWKSource; -import com.nimbusds.jose.proc.BadJOSEException; -import com.nimbusds.jose.proc.JWSKeySelector; -import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; import com.nimbusds.jwt.SignedJWT; -import com.nimbusds.jwt.proc.DefaultJWTProcessor; -import com.nimbusds.jwt.proc.JWTProcessor; -import reactor.core.publisher.Mono; - import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; -import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +import java.security.interfaces.RSAPublicKey; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; /** * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a @@ -64,56 +49,23 @@ * @see JSON Web Key (JWK) * @see Nimbus JOSE + JWT SDK */ -public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { - private final JWTProcessor jwtProcessor; - - private final ReactiveJWKSource reactiveJwkSource; - - private final JWKSelectorFactory jwkSelectorFactory; +public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { + private final ReactiveJWTProcessor jwtProcessor; private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); private Converter, Map> claimSetConverter = MappedJwtClaimSetConverter .withDefaults(Collections.emptyMap()); - public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) { - JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256); - - RSAKey rsaKey = rsaKey(publicKey); - JWKSet jwkSet = new JWKSet(rsaKey); - JWKSource jwkSource = new ImmutableJWKSet<>(jwkSet); - JWSKeySelector jwsKeySelector = - new JWSVerificationKeySelector<>(algorithm, jwkSource); - DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); - jwtProcessor.setJWSKeySelector(jwsKeySelector); - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); - - this.jwtProcessor = jwtProcessor; - this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource); - this.jwkSelectorFactory = new JWKSelectorFactory(algorithm); + public NimbusReactiveJwtDecoder(String jwksUri){ + this(new ReactiveJWKSJWTProcessor(jwksUri)); } - /** - * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters. - * - * @param jwkSetUrl the JSON Web Key (JWK) Set {@code URL} - */ - public NimbusReactiveJwtDecoder(String jwkSetUrl) { - Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty"); - String jwsAlgorithm = JwsAlgorithms.RS256; - JWSAlgorithm algorithm = JWSAlgorithm.parse(jwsAlgorithm); - JWKSource jwkSource = new JWKContextJWKSource(); - JWSKeySelector jwsKeySelector = - new JWSVerificationKeySelector<>(algorithm, jwkSource); - - DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); - jwtProcessor.setJWSKeySelector(jwsKeySelector); - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); - this.jwtProcessor = jwtProcessor; - - this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl); - - this.jwkSelectorFactory = new JWKSelectorFactory(algorithm); + public NimbusReactiveJwtDecoder(RSAPublicKey publicKey){ + this(new ReactivePublicKeyJWTProcessor(publicKey)); + } + public NimbusReactiveJwtDecoder(ReactiveJWTProcessor jwtProcessor){ + this.jwtProcessor=jwtProcessor; } /** @@ -155,11 +107,8 @@ private JWT parse(String token) { private Mono decode(SignedJWT parsedToken) { try { - JWKSelector selector = this.jwkSelectorFactory - .createSelector(parsedToken.getHeader()); - return this.reactiveJwkSource.get(selector) - .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) - .map(jwkList -> createClaimsSet(parsedToken, jwkList)) + return jwtProcessor.process(parsedToken) + .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new IllegalStateException("Could not obtain the keys", e)) .map(set -> createJwt(parsedToken, set)) .map(this::validateJwt) .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); @@ -168,15 +117,6 @@ private Mono decode(SignedJWT parsedToken) { } } - private JWTClaimsSet createClaimsSet(JWT parsedToken, List jwkList) { - try { - return this.jwtProcessor.process(parsedToken, new JWKContext(jwkList)); - } - catch (BadJOSEException | JOSEException e) { - throw new JwtException("Failed to validate the token", e); - } - } - private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) { Map headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); Map claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); @@ -196,9 +136,4 @@ private Jwt validateJwt(Jwt jwt) { return jwt; } - - private static RSAKey rsaKey(RSAPublicKey publicKey) { - return new RSAKey.Builder(publicKey) - .build(); - } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSJWTProcessor.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSJWTProcessor.java new file mode 100644 index 00000000000..9769edfb9f0 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSJWTProcessor.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.BadJOSEException; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jwt.JWT; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jwt.proc.DefaultJWTProcessor; +import com.nimbusds.jwt.proc.JWTProcessor; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +/** + * A simple implementation of {@link ReactiveJWTProcessor}. + * This implementation is mainly a wrapper around a simple {@link com.nimbusds.jwt.proc.JWTProcessor} + * but with a reactive initialization to get public key from a JWKS endpoint. + */ +public class ReactiveJWKSJWTProcessor implements ReactiveJWTProcessor { + private final JWTProcessor jwtProcessor; + private final JWKSelectorFactory jwkSelectorFactory; + private final ReactiveRemoteJWKSource reactiveJwkSource; + + public ReactiveJWKSJWTProcessor(String jwkSetUrl) { + this(jwkSetUrl, WebClient.create(), JWSAlgorithm.parse(JwsAlgorithms.RS256)); + } + + public ReactiveJWKSJWTProcessor(String jwkSetUrl, WebClient webClient) { + this(jwkSetUrl, webClient, JWSAlgorithm.parse(JwsAlgorithms.RS256)); + } + + public ReactiveJWKSJWTProcessor(String jwkSetUrl, JWSAlgorithm algorithm) { + this(jwkSetUrl, WebClient.create(), algorithm); + } + + public ReactiveJWKSJWTProcessor(String jwkSetUrl, WebClient webClient, JWSAlgorithm algorithm) { + Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty"); + + JWKSource jwkSource = new JWKContextJWKSource(); + JWSKeySelector jwsKeySelector = new JWSVerificationKeySelector<>(algorithm, jwkSource); + + DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); + jwtProcessor.setJWSKeySelector(jwsKeySelector); + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); + + reactiveJwkSource = new ReactiveRemoteJWKSource(webClient, jwkSetUrl); + + jwkSelectorFactory = new JWKSelectorFactory(algorithm); + + this.jwtProcessor = jwtProcessor; + } + + public Mono process(SignedJWT jwt) { + return Mono.defer(() -> { + try { + JWKSelector select = jwkSelectorFactory.createSelector(jwt.getHeader()); + return reactiveJwkSource + .get(select) + .map(JWKContext::new) + .map(context -> createClaimsSet(jwt, context)); + } catch (RuntimeException ex) { + return Mono.error(new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex)); + } + }); + } + + private JWTClaimsSet createClaimsSet(JWT parsedToken, JWKContext context) { + try { + return this.jwtProcessor.process(parsedToken, context); + } catch (BadJOSEException | JOSEException e) { + throw new JwtException("Failed to validate the token", e); + } + } +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWTProcessor.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWTProcessor.java new file mode 100644 index 00000000000..ddd7309d23f --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWTProcessor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import com.nimbusds.jwt.EncryptedJWT; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import reactor.core.publisher.Mono; + +/** + * Interface for parsing and processing JWT token. + * It is somehow a reactive version of {@link com.nimbusds.jwt.proc.JWTProcessor}. + */ +public interface ReactiveJWTProcessor { + default Mono process(SignedJWT jwt) { + throw new JwtException("Signed JWT not supported"); + } + + default Mono process(EncryptedJWT jwt) { + throw new JwtException("Encrypted JWT not supported"); + } +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactivePublicKeyJWTProcessor.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactivePublicKeyJWTProcessor.java new file mode 100644 index 00000000000..06b8ceef398 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactivePublicKeyJWTProcessor.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.source.ImmutableJWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.BadJOSEException; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jwt.proc.DefaultJWTProcessor; +import com.nimbusds.jwt.proc.JWTProcessor; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; +import reactor.core.publisher.Mono; + +import java.security.interfaces.RSAPublicKey; + +/** + * A simple implementation of {@link ReactiveJWTProcessor}. + * This implementation is mainly a wrapper around a simple {@link com.nimbusds.jwt.proc.JWTProcessor}. + * JWT will be validated against a provided public key. + */ +public class ReactivePublicKeyJWTProcessor implements ReactiveJWTProcessor { + private final JWTProcessor jwtProcessor; + + public ReactivePublicKeyJWTProcessor(RSAPublicKey publicKey) { + this(publicKey, JWSAlgorithm.parse(JwsAlgorithms.RS256)); + } + + public ReactivePublicKeyJWTProcessor(RSAPublicKey publicKey, JWSAlgorithm algorithm) { + RSAKey rsaKey = new RSAKey.Builder(publicKey).build(); + JWKSet jwkSet = new JWKSet(rsaKey); + JWKSource jwkSource = new ImmutableJWKSet<>(jwkSet); + JWSKeySelector jwsKeySelector = new JWSVerificationKeySelector<>(algorithm, jwkSource); + + DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); + jwtProcessor.setJWSKeySelector(jwsKeySelector); + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); + + this.jwtProcessor = jwtProcessor; + } + + public Mono process(SignedJWT jwt) { + return Mono.defer(() -> { + try { + return Mono.just(jwtProcessor.process(jwt, null)); + } catch (BadJOSEException | JOSEException e) { + return Mono.error(e); + } + }); + } +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java index 0079c8b8634..df40942392f 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java @@ -40,11 +40,12 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private final AtomicReference> cachedJWKSet = new AtomicReference<>(Mono.empty()); - private WebClient webClient = WebClient.create(); + private final WebClient webClient; private final String jwkSetURL; - ReactiveRemoteJWKSource(String jwkSetURL) { + ReactiveRemoteJWKSource(WebClient webClient, String jwkSetURL) { + this.webClient = webClient; this.jwkSetURL = jwkSetURL; } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 14df7c87e89..6cd91a48b24 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -16,32 +16,29 @@ package org.springframework.security.oauth2.jwt; -import java.net.UnknownHostException; -import java.security.KeyFactory; -import java.security.interfaces.RSAPublicKey; -import java.security.spec.X509EncodedKeySpec; -import java.time.Instant; -import java.util.Base64; -import java.util.Collections; -import java.util.Map; - import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.After; import org.junit.Before; import org.junit.Test; - import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +import java.net.UnknownHostException; +import java.security.KeyFactory; +import java.security.interfaces.RSAPublicKey; +import java.security.spec.X509EncodedKeySpec; +import java.time.Instant; +import java.util.Base64; +import java.util.Collections; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * @author Rob Winch @@ -76,6 +73,8 @@ public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); this.server.enqueue(new MockResponse().setBody(jwkSet)); + + this.decoder = new NimbusReactiveJwtDecoder(this.server.url("/certs").toString()); } @@ -96,7 +95,7 @@ public void decodeWhenInvalidUrl() { @Test public void decodeWhenMessageReadScopeThenSuccess() { - Jwt jwt = this.decoder.decode(this.messageReadToken).block(); + Jwt jwt = (Jwt) this.decoder.decode(this.messageReadToken).block(); assertThat(jwt.getClaims().get("scope")).isEqualTo("message:read"); } @@ -117,7 +116,7 @@ public void decodeWhenRSAPublicKeyThenSuccess() throws Exception { public void decodeWhenIssuedAtThenSuccess() { String withIssuedAt = "eyJraWQiOiJrZXktaWQtMSIsImFsZyI6IlJTMjU2In0.eyJzY29wZSI6IiIsImV4cCI6OTIyMzM3MjAwNjA5NjM3NSwiaWF0IjoxNTI5OTQyNDQ4fQ.LBzAJO-FR-uJDHST61oX4kimuQjz6QMJPW_mvEXRB6A-fMQWpfTQ089eboipAqsb33XnwWth9ELju9HMWLk0FjlWVVzwObh9FcoKelmPNR8mZIlFG-pAYGgSwi8HufyLabXHntFavBiFtqwp_z9clSOFK1RxWvt3lywEbGgtCKve0BXOjfKWiH1qe4QKGixH-NFxidvz8Qd5WbJwyb9tChC6ZKoKPv7Jp-N5KpxkY-O2iUtINvn4xOSactUsvKHgF8ZzZjvJGzG57r606OZXaNtoElQzjAPU5xDGg5liuEJzfBhvqiWCLRmSuZ33qwp3aoBnFgEw0B85gsNe3ggABg"; - Jwt jwt = this.decoder.decode(withIssuedAt).block(); + Jwt jwt = (Jwt) this.decoder.decode(withIssuedAt).block(); assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1529942448L)); } @@ -188,7 +187,7 @@ public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value")); - Jwt jwt = this.decoder.decode(this.messageReadToken).block(); + Jwt jwt = (Jwt) this.decoder.decode(this.messageReadToken).block(); assertThat(jwt.getClaims().size()).isEqualTo(1); assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); verify(claimSetConverter).convert(any(Map.class)); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java index 58b20a28363..3d0c6306305 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java @@ -28,6 +28,7 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.web.reactive.function.client.WebClient; import java.util.Collections; import java.util.List; @@ -89,7 +90,7 @@ public class ReactiveRemoteJWKSourceTests { @Before public void setup() { this.server = new MockWebServer(); - this.source = new ReactiveRemoteJWKSource(this.server.url("/").toString()); + this.source = new ReactiveRemoteJWKSource(WebClient.create(), this.server.url("/").toString()); this.server.enqueue(new MockResponse().setBody(this.keys)); this.selector = new JWKSelector(this.matcher);