From 943f1da60755eaf652276b7263c6299554726253 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 16 Nov 2020 19:56:45 -0500 Subject: [PATCH 1/4] Add Jwt Client Authentication support Closes gh-8175 --- ...stractOAuth2AuthorizationGrantRequest.java | 34 +- .../oauth2/client/endpoint/JoseHeader.java | 391 ++++++++++++++++++ .../client/endpoint/JoseHeaderNames.java | 127 ++++++ .../client/endpoint/JwsHeaderConverter.java | 144 +++++++ .../oauth2/client/endpoint/JwtClaimsSet.java | 222 ++++++++++ .../endpoint/JwtClaimsSetConverter.java | 109 +++++ .../oauth2/client/endpoint/JwtEncoder.java | 77 ++++ .../client/endpoint/JwtEncodingException.java | 62 +++ .../client/endpoint/NimbusJwsEncoder.java | 213 ++++++++++ ...mbusJwtClientAuthenticationCustomizer.java | 216 ++++++++++ .../OAuth2AuthorizationCodeGrantRequest.java | 16 +- ...zationCodeGrantRequestEntityConverter.java | 46 ++- ...horizationGrantRequestEntityConverter.java | 61 +++ ...2AuthorizationGrantRequestEntityUtils.java | 6 +- .../OAuth2ClientCredentialsGrantRequest.java | 16 +- ...redentialsGrantRequestEntityConverter.java | 42 +- .../endpoint/OAuth2PasswordGrantRequest.java | 16 +- ...h2PasswordGrantRequestEntityConverter.java | 46 ++- .../OAuth2RefreshTokenGrantRequest.java | 16 +- ...freshTokenGrantRequestEntityConverter.java | 45 +- ...orizationCodeTokenResponseClientTests.java | 160 ++++--- ...ntCredentialsTokenResponseClientTests.java | 140 +++++-- ...faultPasswordTokenResponseClientTests.java | 115 +++++- ...tRefreshTokenTokenResponseClientTests.java | 119 +++++- .../client/endpoint/JoseHeaderTests.java | 123 ++++++ .../client/endpoint/JwtClaimsSetTests.java | 105 +++++ .../endpoint/NimbusJwsEncoderTests.java | 347 ++++++++++++++++ ...wtClientAuthenticationCustomizerTests.java | 216 ++++++++++ ...nCodeGrantRequestEntityConverterTests.java | 105 +++-- ...tialsGrantRequestEntityConverterTests.java | 53 ++- ...swordGrantRequestEntityConverterTests.java | 47 ++- ...TokenGrantRequestEntityConverterTests.java | 47 ++- .../client/endpoint/TestJoseHeaders.java | 76 ++++ .../client/endpoint/TestJwtClaimsSets.java | 64 +++ .../core/ClientAuthenticationMethod.java | 13 +- .../core/endpoint/OAuth2ParameterNames.java | 14 +- .../core/ClientAuthenticationMethodTests.java | 12 +- .../security/oauth2/jose/TestJwks.java | 86 ++++ .../security/oauth2/jose/TestKeys.java | 37 +- 39 files changed, 3451 insertions(+), 333 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java create mode 100644 oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java index a0d5a698d41..8016896b321 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; @@ -27,6 +28,7 @@ * @author Joe Grandja * @since 5.0 * @see AuthorizationGrantType + * @see ClientRegistration * @see Section * 1.3 Authorization Grant */ @@ -34,13 +36,34 @@ public abstract class AbstractOAuth2AuthorizationGrantRequest { private final AuthorizationGrantType authorizationGrantType; + private final ClientRegistration clientRegistration; + /** * Sub-class constructor. * @param authorizationGrantType the authorization grant type + * @deprecated Use + * {@link #AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType, ClientRegistration)} + * instead */ + @Deprecated protected AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType authorizationGrantType) { Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null"); this.authorizationGrantType = authorizationGrantType; + this.clientRegistration = null; + } + + /** + * Sub-class constructor. + * @param authorizationGrantType the authorization grant type + * @param clientRegistration the client registration + * @since 5.5 + */ + protected AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType authorizationGrantType, + ClientRegistration clientRegistration) { + Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null"); + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.authorizationGrantType = authorizationGrantType; + this.clientRegistration = clientRegistration; } /** @@ -51,4 +74,13 @@ public AuthorizationGrantType getGrantType() { return this.authorizationGrantType; } + /** + * Returns the {@link ClientRegistration client registration}. + * @return the {@link ClientRegistration} + * @since 5.5 + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java new file mode 100644 index 00000000000..e2d01b52ddf --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java @@ -0,0 +1,391 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.net.URL; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import org.springframework.security.oauth2.core.converter.ClaimConversionService; +import org.springframework.security.oauth2.jose.JwaAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.Assert; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * The JOSE header is a JSON object representing the header parameters of a JSON Web + * Token, whether the JWT is a JWS or JWE, that describe the cryptographic operations + * applied to the JWT and optionally, additional properties of the JWT. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 5.5 + * @see Jwt + * @see JWT JOSE + * Header + * @see JWS JOSE + * Header + * @see JWE JOSE + * Header + */ +final class JoseHeader { + + private final Map headers; + + private JoseHeader(Map headers) { + this.headers = Collections.unmodifiableMap(new HashMap<>(headers)); + } + + /** + * Returns the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or + * encrypt the JWE. + * @return the {@link JwaAlgorithm} + */ + @SuppressWarnings("unchecked") + T getAlgorithm() { + return (T) getHeader(JoseHeaderNames.ALG); + } + + /** + * Returns the JWK Set URL that refers to the resource of a set of JSON-encoded public + * keys, one of which corresponds to the key used to digitally sign the JWS or encrypt + * the JWE. + * @return the JWK Set URL + */ + URL getJwkSetUri() { + return getHeader(JoseHeaderNames.JKU); + } + + /** + * Returns the JSON Web Key which is the public key that corresponds to the key used + * to digitally sign the JWS or encrypt the JWE. + * @return the JSON Web Key + */ + Map getJwk() { + return getHeader(JoseHeaderNames.JWK); + } + + /** + * Returns the key ID that is a hint indicating which key was used to secure the JWS + * or JWE. + * @return the key ID + */ + String getKeyId() { + return getHeader(JoseHeaderNames.KID); + } + + /** + * Returns the X.509 URL that refers to the resource for the X.509 public key + * certificate or certificate chain corresponding to the key used to digitally sign + * the JWS or encrypt the JWE. + * @return the X.509 URL + */ + URL getX509Uri() { + return getHeader(JoseHeaderNames.X5U); + } + + /** + * Returns the X.509 certificate chain that contains the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign the JWS or + * encrypt the JWE. + * @return the X.509 certificate chain + */ + List getX509CertificateChain() { + return getHeader(JoseHeaderNames.X5C); + } + + /** + * Returns the X.509 certificate SHA-1 thumbprint that is a base64url-encoded SHA-1 + * thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate + * corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * @return the X.509 certificate SHA-1 thumbprint + */ + String getX509SHA1Thumbprint() { + return getHeader(JoseHeaderNames.X5T); + } + + /** + * Returns the X.509 certificate SHA-256 thumbprint that is a base64url-encoded + * SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate + * corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * @return the X.509 certificate SHA-256 thumbprint + */ + String getX509SHA256Thumbprint() { + return getHeader(JoseHeaderNames.X5T_S256); + } + + /** + * Returns the type header that declares the media type of the JWS/JWE. + * @return the type header + */ + String getType() { + return getHeader(JoseHeaderNames.TYP); + } + + /** + * Returns the content type header that declares the media type of the secured content + * (the payload). + * @return the content type header + */ + String getContentType() { + return getHeader(JoseHeaderNames.CTY); + } + + /** + * Returns the critical headers that indicates which extensions to the JWS/JWE/JWA + * specifications are being used that MUST be understood and processed. + * @return the critical headers + */ + Set getCritical() { + return getHeader(JoseHeaderNames.CRIT); + } + + /** + * Returns the headers. + * @return the headers + */ + Map getHeaders() { + return this.headers; + } + + /** + * Returns the header value. + * @param name the header name + * @param the type of the header value + * @return the header value + */ + @SuppressWarnings("unchecked") + T getHeader(String name) { + Assert.hasText(name, "name cannot be empty"); + return (T) getHeaders().get(name); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@link JwaAlgorithm}. + * @param jwaAlgorithm the {@link JwaAlgorithm} + * @return the {@link Builder} + */ + static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) { + return new Builder(jwaAlgorithm); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@code headers}. + * @param headers the headers + * @return the {@link Builder} + */ + static Builder from(JoseHeader headers) { + return new Builder(headers); + } + + /** + * A builder for {@link JoseHeader}. + */ + static final class Builder { + + final Map headers = new HashMap<>(); + + private Builder(JwaAlgorithm jwaAlgorithm) { + algorithm(jwaAlgorithm); + } + + private Builder(JoseHeader headers) { + Assert.notNull(headers, "headers cannot be null"); + this.headers.putAll(headers.getHeaders()); + } + + /** + * Sets the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or + * encrypt the JWE. + * @param jwaAlgorithm the {@link JwaAlgorithm} + * @return the {@link Builder} + */ + Builder algorithm(JwaAlgorithm jwaAlgorithm) { + Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null"); + return header(JoseHeaderNames.ALG, jwaAlgorithm); + } + + /** + * Sets the JWK Set URL that refers to the resource of a set of JSON-encoded + * public keys, one of which corresponds to the key used to digitally sign the JWS + * or encrypt the JWE. + * @param jwkSetUri the JWK Set URL + * @return the {@link Builder} + */ + Builder jwkSetUri(String jwkSetUri) { + return header(JoseHeaderNames.JKU, jwkSetUri); + } + + /** + * Sets the JSON Web Key which is the public key that corresponds to the key used + * to digitally sign the JWS or encrypt the JWE. + * @param jwk the JSON Web Key + * @return the {@link Builder} + */ + Builder jwk(Map jwk) { + return header(JoseHeaderNames.JWK, jwk); + } + + /** + * Sets the key ID that is a hint indicating which key was used to secure the JWS + * or JWE. + * @param keyId the key ID + * @return the {@link Builder} + */ + Builder keyId(String keyId) { + return header(JoseHeaderNames.KID, keyId); + } + + /** + * Sets the X.509 URL that refers to the resource for the X.509 public key + * certificate or certificate chain corresponding to the key used to digitally + * sign the JWS or encrypt the JWE. + * @param x509Uri the X.509 URL + * @return the {@link Builder} + */ + Builder x509Uri(String x509Uri) { + return header(JoseHeaderNames.X5U, x509Uri); + } + + /** + * Sets the X.509 certificate chain that contains the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign the JWS or + * encrypt the JWE. + * @param x509CertificateChain the X.509 certificate chain + * @return the {@link Builder} + */ + Builder x509CertificateChain(List x509CertificateChain) { + return header(JoseHeaderNames.X5C, x509CertificateChain); + } + + /** + * Sets the X.509 certificate SHA-1 thumbprint that is a base64url-encoded SHA-1 + * thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate + * corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * @param x509SHA1Thumbprint the X.509 certificate SHA-1 thumbprint + * @return the {@link Builder} + */ + Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { + return header(JoseHeaderNames.X5T, x509SHA1Thumbprint); + } + + /** + * Sets the X.509 certificate SHA-256 thumbprint that is a base64url-encoded + * SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate + * corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * @param x509SHA256Thumbprint the X.509 certificate SHA-256 thumbprint + * @return the {@link Builder} + */ + Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) { + return header(JoseHeaderNames.X5T_S256, x509SHA256Thumbprint); + } + + /** + * Sets the type header that declares the media type of the JWS/JWE. + * @param type the type header + * @return the {@link Builder} + */ + Builder type(String type) { + return header(JoseHeaderNames.TYP, type); + } + + /** + * Sets the content type header that declares the media type of the secured + * content (the payload). + * @param contentType the content type header + * @return the {@link Builder} + */ + Builder contentType(String contentType) { + return header(JoseHeaderNames.CTY, contentType); + } + + /** + * Sets the critical headers that indicates which extensions to the JWS/JWE/JWA + * specifications are being used that MUST be understood and processed. + * @param headerNames the critical header names + * @return the {@link Builder} + */ + Builder critical(Set headerNames) { + return header(JoseHeaderNames.CRIT, headerNames); + } + + /** + * Sets the header. + * @param name the header name + * @param value the header value + * @return the {@link Builder} + */ + Builder header(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.headers.put(name, value); + return this; + } + + /** + * A {@code Consumer} to be provided access to the headers allowing the ability to + * add, replace, or remove. + * @param headersConsumer a {@code Consumer} of the headers + * @return the {@link Builder} + */ + Builder headers(Consumer> headersConsumer) { + headersConsumer.accept(this.headers); + return this; + } + + /** + * Builds a new {@link JoseHeader}. + * @return a {@link JoseHeader} + */ + JoseHeader build() { + Assert.notEmpty(this.headers, "headers cannot be empty"); + convertAsURL(JoseHeaderNames.JKU); + convertAsURL(JoseHeaderNames.X5U); + return new JoseHeader(this.headers); + } + + private void convertAsURL(String header) { + Object value = this.headers.get(header); + if (value != null) { + URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class); + Assert.isTrue(convertedValue != null, + () -> "Unable to convert header '" + header + "' of type '" + value.getClass() + "' to URL."); + this.headers.put(header, convertedValue); + } + } + + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java new file mode 100644 index 00000000000..41abd7eeba0 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * The Registered Header Parameter Names defined by the JSON Web Token (JWT), JSON Web + * Signature (JWS) and JSON Web Encryption (JWE) specifications that may be contained in + * the JOSE Header of a JWT. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 5.5 + * @see JoseHeader + * @see JWT JOSE + * Header + * @see JWS JOSE + * Header + * @see JWE JOSE + * Header + */ +final class JoseHeaderNames { + + /** + * {@code alg} - the algorithm header identifies the cryptographic algorithm used to + * secure a JWS or JWE + */ + static final String ALG = "alg"; + + /** + * {@code jku} - the JWK Set URL header is a URI that refers to a resource for a set + * of JSON-encoded public keys, one of which corresponds to the key used to digitally + * sign a JWS or encrypt a JWE + */ + static final String JKU = "jku"; + + /** + * {@code jwk} - the JSON Web Key header is the public key that corresponds to the key + * used to digitally sign a JWS or encrypt a JWE + */ + static final String JWK = "jwk"; + + /** + * {@code kid} - the key ID header is a hint indicating which key was used to secure a + * JWS or JWE + */ + static final String KID = "kid"; + + /** + * {@code x5u} - the X.509 URL header is a URI that refers to a resource for the X.509 + * public key certificate or certificate chain corresponding to the key used to + * digitally sign a JWS or encrypt a JWE + */ + static final String X5U = "x5u"; + + /** + * {@code x5c} - the X.509 certificate chain header contains the X.509 public key + * certificate or certificate chain corresponding to the key used to digitally sign a + * JWS or encrypt a JWE + */ + static final String X5C = "x5c"; + + /** + * {@code x5t} - the X.509 certificate SHA-1 thumbprint header is a base64url-encoded + * SHA-1 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate + * corresponding to the key used to digitally sign a JWS or encrypt a JWE + */ + static final String X5T = "x5t"; + + /** + * {@code x5t#S256} - the X.509 certificate SHA-256 thumbprint header is a + * base64url-encoded SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the + * X.509 certificate corresponding to the key used to digitally sign a JWS or encrypt + * a JWE + */ + static final String X5T_S256 = "x5t#S256"; + + /** + * {@code typ} - the type header is used by JWS/JWE applications to declare the media + * type of a JWS/JWE + */ + static final String TYP = "typ"; + + /** + * {@code cty} - the content type header is used by JWS/JWE applications to declare + * the media type of the secured content (the payload) + */ + static final String CTY = "cty"; + + /** + * {@code crit} - the critical header indicates that extensions to the JWS/JWE/JWA + * specifications are being used that MUST be understood and processed + */ + static final String CRIT = "crit"; + + private JoseHeaderNames() { + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java new file mode 100644 index 00000000000..da3ccb32bc6 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.net.URL; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.util.Base64; +import com.nimbusds.jose.util.Base64URL; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * A {@link Converter} that converts a {@link JoseHeader} to + * {@code com.nimbusds.jose.JWSHeader}. + * + * @author Joe Grandja + * @since 5.5 + * @see Converter + * @see JoseHeader + * @see com.nimbusds.jose.JWSHeader + */ +final class JwsHeaderConverter implements Converter { + + @Override + public JWSHeader convert(JoseHeader headers) { + JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getAlgorithm().getName())); + + URL jwkSetUri = headers.getJwkSetUri(); + if (jwkSetUri != null) { + try { + builder.jwkURL(jwkSetUri.toURI()); + } + catch (Exception ex) { + throw new IllegalArgumentException( + "Unable to convert '" + JoseHeaderNames.JKU + "' JOSE header to a URI", ex); + } + } + + Map jwk = headers.getJwk(); + if (!CollectionUtils.isEmpty(jwk)) { + try { + builder.jwk(JWK.parse(jwk)); + } + catch (Exception ex) { + throw new IllegalArgumentException("Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header", ex); + } + } + + String keyId = headers.getKeyId(); + if (StringUtils.hasText(keyId)) { + builder.keyID(keyId); + } + + URL x509Uri = headers.getX509Uri(); + if (x509Uri != null) { + try { + builder.x509CertURL(x509Uri.toURI()); + } + catch (Exception ex) { + throw new IllegalArgumentException( + "Unable to convert '" + JoseHeaderNames.X5U + "' JOSE header to a URI", ex); + } + } + + List x509CertificateChain = headers.getX509CertificateChain(); + if (!CollectionUtils.isEmpty(x509CertificateChain)) { + builder.x509CertChain(x509CertificateChain.stream().map(Base64::new).collect(Collectors.toList())); + } + + String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint(); + if (StringUtils.hasText(x509SHA1Thumbprint)) { + builder.x509CertThumbprint(new Base64URL(x509SHA1Thumbprint)); + } + + String x509SHA256Thumbprint = headers.getX509SHA256Thumbprint(); + if (StringUtils.hasText(x509SHA256Thumbprint)) { + builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint)); + } + + String type = headers.getType(); + if (StringUtils.hasText(type)) { + builder.type(new JOSEObjectType(type)); + } + + String contentType = headers.getContentType(); + if (StringUtils.hasText(contentType)) { + builder.contentType(contentType); + } + + Set critical = headers.getCritical(); + if (!CollectionUtils.isEmpty(critical)) { + builder.criticalParams(critical); + } + + Map customHeaders = headers.getHeaders().entrySet().stream() + .filter((header) -> !JWSHeader.getRegisteredParameterNames().contains(header.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + if (!CollectionUtils.isEmpty(customHeaders)) { + builder.customParams(customHeaders); + } + + return builder.build(); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java new file mode 100644 index 00000000000..d383c04b661 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.net.URL; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.security.oauth2.core.converter.ClaimConversionService; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimAccessor; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.util.Assert; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * The {@link Jwt JWT} Claims Set is a JSON object representing the claims conveyed by a + * JSON Web Token. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 5.5 + * @see Jwt + * @see JwtClaimAccessor + * @see JWT Claims + * Set + */ +final class JwtClaimsSet implements JwtClaimAccessor { + + private final Map claims; + + private JwtClaimsSet(Map claims) { + this.claims = Collections.unmodifiableMap(new HashMap<>(claims)); + } + + @Override + public Map getClaims() { + return this.claims; + } + + /** + * Returns a new {@link Builder}. + * @return the {@link Builder} + */ + static Builder builder() { + return new Builder(); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@code claims}. + * @param claims a JWT claims set + * @return the {@link Builder} + */ + static Builder from(JwtClaimsSet claims) { + return new Builder(claims); + } + + /** + * A builder for {@link JwtClaimsSet}. + */ + static final class Builder { + + final Map claims = new HashMap<>(); + + private Builder() { + } + + private Builder(JwtClaimsSet claims) { + Assert.notNull(claims, "claims cannot be null"); + this.claims.putAll(claims.getClaims()); + } + + /** + * Sets the issuer {@code (iss)} claim, which identifies the principal that issued + * the JWT. + * @param issuer the issuer identifier + * @return the {@link Builder} + */ + Builder issuer(String issuer) { + return claim(JwtClaimNames.ISS, issuer); + } + + /** + * Sets the subject {@code (sub)} claim, which identifies the principal that is + * the subject of the JWT. + * @param subject the subject identifier + * @return the {@link Builder} + */ + Builder subject(String subject) { + return claim(JwtClaimNames.SUB, subject); + } + + /** + * Sets the audience {@code (aud)} claim, which identifies the recipient(s) that + * the JWT is intended for. + * @param audience the audience that this JWT is intended for + * @return the {@link Builder} + */ + Builder audience(List audience) { + return claim(JwtClaimNames.AUD, audience); + } + + /** + * Sets the expiration time {@code (exp)} claim, which identifies the time on or + * after which the JWT MUST NOT be accepted for processing. + * @param expiresAt the time on or after which the JWT MUST NOT be accepted for + * processing + * @return the {@link Builder} + */ + Builder expiresAt(Instant expiresAt) { + return claim(JwtClaimNames.EXP, expiresAt); + } + + /** + * Sets the not before {@code (nbf)} claim, which identifies the time before which + * the JWT MUST NOT be accepted for processing. + * @param notBefore the time before which the JWT MUST NOT be accepted for + * processing + * @return the {@link Builder} + */ + Builder notBefore(Instant notBefore) { + return claim(JwtClaimNames.NBF, notBefore); + } + + /** + * Sets the issued at {@code (iat)} claim, which identifies the time at which the + * JWT was issued. + * @param issuedAt the time at which the JWT was issued + * @return the {@link Builder} + */ + Builder issuedAt(Instant issuedAt) { + return claim(JwtClaimNames.IAT, issuedAt); + } + + /** + * Sets the JWT ID {@code (jti)} claim, which provides a unique identifier for the + * JWT. + * @param jti the unique identifier for the JWT + * @return the {@link Builder} + */ + Builder id(String jti) { + return claim(JwtClaimNames.JTI, jti); + } + + /** + * Sets the claim. + * @param name the claim name + * @param value the claim value + * @return the {@link Builder} + */ + Builder claim(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.claims.put(name, value); + return this; + } + + /** + * A {@code Consumer} to be provided access to the claims allowing the ability to + * add, replace, or remove. + * @param claimsConsumer a {@code Consumer} of the claims + */ + Builder claims(Consumer> claimsConsumer) { + claimsConsumer.accept(this.claims); + return this; + } + + /** + * Builds a new {@link JwtClaimsSet}. + * @return a {@link JwtClaimsSet} + */ + JwtClaimsSet build() { + Assert.notEmpty(this.claims, "claims cannot be empty"); + + // The value of the 'iss' claim is a String or URL (StringOrURI). + // Attempt to convert to URL. + Object issuer = this.claims.get(JwtClaimNames.ISS); + if (issuer != null) { + URL convertedValue = ClaimConversionService.getSharedInstance().convert(issuer, URL.class); + if (convertedValue != null) { + this.claims.put(JwtClaimNames.ISS, convertedValue); + } + } + + return new JwtClaimsSet(this.claims); + } + + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java new file mode 100644 index 00000000000..58cdd51fd28 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.time.Instant; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.nimbusds.jwt.JWTClaimsSet; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * A {@link Converter} that converts a {@link JwtClaimsSet} to + * {@code com.nimbusds.jwt.JWTClaimsSet}. + * + * @author Joe Grandja + * @since 5.5 + * @see Converter + * @see JwtClaimsSet + * @see com.nimbusds.jwt.JWTClaimsSet + */ +final class JwtClaimsSetConverter implements Converter { + + @Override + public JWTClaimsSet convert(JwtClaimsSet claims) { + JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder(); + + // NOTE: The value of the 'iss' claim is a String or URL (StringOrURI). + Object issuer = claims.getClaim(JwtClaimNames.ISS); + if (issuer != null) { + builder.issuer(issuer.toString()); + } + + String subject = claims.getSubject(); + if (StringUtils.hasText(subject)) { + builder.subject(subject); + } + + List audience = claims.getAudience(); + if (!CollectionUtils.isEmpty(audience)) { + builder.audience(audience); + } + + Instant expiresAt = claims.getExpiresAt(); + if (expiresAt != null) { + builder.expirationTime(Date.from(expiresAt)); + } + + Instant notBefore = claims.getNotBefore(); + if (notBefore != null) { + builder.notBeforeTime(Date.from(notBefore)); + } + + Instant issuedAt = claims.getIssuedAt(); + if (issuedAt != null) { + builder.issueTime(Date.from(issuedAt)); + } + + String jwtId = claims.getId(); + if (StringUtils.hasText(jwtId)) { + builder.jwtID(jwtId); + } + + Map customClaims = claims.getClaims().entrySet().stream() + .filter((claim) -> !JWTClaimsSet.getRegisteredNames().contains(claim.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + if (!CollectionUtils.isEmpty(customClaims)) { + customClaims.forEach(builder::claim); + } + + return builder.build(); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java new file mode 100644 index 00000000000..fc20abf626e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * Implementations of this interface are responsible for encoding a JSON Web Token (JWT) + * to it's compact claims representation format. + * + *

+ * JWTs may be represented using the JWS Compact Serialization format for a JSON Web + * Signature (JWS) structure or JWE Compact Serialization format for a JSON Web Encryption + * (JWE) structure. Therefore, implementors are responsible for signing a JWS and/or + * encrypting a JWE. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 5.5 + * @see Jwt + * @see JoseHeader + * @see JwtClaimsSet + * @see JwtDecoder + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Encryption + * (JWE) + * @see JWS + * Compact Serialization + * @see JWE + * Compact Serialization + */ +@FunctionalInterface +interface JwtEncoder { + + /** + * Encode the JWT to it's compact claims representation format. + * @param headers the JOSE header + * @param claims the JWT Claims Set + * @return a {@link Jwt} + * @throws JwtEncodingException if an error occurs while attempting to encode the JWT + */ + Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException; + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java new file mode 100644 index 00000000000..53c82b13bd5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.jwt.JwtException; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * This exception is thrown when an error occurs while attempting to encode a JSON Web + * Token (JWT). + * + * @author Joe Grandja + * @since 5.5 + */ +class JwtEncodingException extends JwtException { + + /** + * Constructs a {@code JwtEncodingException} using the provided parameters. + * @param message the detail message + */ + JwtEncodingException(String message) { + super(message); + } + + /** + * Constructs a {@code JwtEncodingException} using the provided parameters. + * @param message the detail message + * @param cause the root cause + */ + JwtEncodingException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java new file mode 100644 index 00000000000..08936959e2b --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java @@ -0,0 +1,213 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKMatcher; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jose.produce.JWSSignerFactory; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the + * JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for + * signing the JWS is supplied by the {@code com.nimbusds.jose.jwk.source.JWKSource} + * provided via the constructor. + * + *

+ * NOTE: This implementation uses the Nimbus JOSE + JWT SDK. + * + * @author Joe Grandja + * @since 5.5 + * @see JwtEncoder + * @see com.nimbusds.jose.jwk.source.JWKSource + * @see com.nimbusds.jose.jwk.JWK + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JWS + * Compact Serialization + * @see Nimbus + * JOSE + JWT SDK + */ +final class NimbusJwsEncoder implements JwtEncoder { + + private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; + + private static final Converter JWS_HEADER_CONVERTER = new JwsHeaderConverter(); + + private static final Converter JWT_CLAIMS_SET_CONVERTER = new JwtClaimsSetConverter(); + + private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory(); + + private final Map jwsSigners = new ConcurrentHashMap<>(); + + private final JWKSource jwkSource; + + /** + * Constructs a {@code NimbusJwsEncoder} using the provided parameters. + * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} + */ + NimbusJwsEncoder(JWKSource jwkSource) { + Assert.notNull(jwkSource, "jwkSource cannot be null"); + this.jwkSource = jwkSource; + } + + @Override + public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { + Assert.notNull(headers, "headers cannot be null"); + Assert.notNull(claims, "claims cannot be null"); + + JWSHeader jwsHeader; + try { + jwsHeader = JWS_HEADER_CONVERTER.convert(headers); + } + catch (Exception ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + } + + JWK jwk = selectJwk(jwsHeader); + if (jwk == null) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); + } + + jwsHeader = addKeyIdentifierHeadersIfNecessary(jwsHeader, jwk); + headers = syncKeyIdentifierHeadersIfNecessary(headers, jwsHeader); + + JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); + + JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, (key) -> { + try { + return JWS_SIGNER_FACTORY.createJWSSigner(key); + } + catch (JOSEException ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to create a JWS Signer -> " + ex.getMessage()), ex); + } + }); + + SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet); + try { + signedJwt.sign(jwsSigner); + } + catch (JOSEException ex) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to sign the JWT -> " + ex.getMessage()), ex); + } + String jws = signedJwt.serialize(); + + return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims()); + } + + private JWK selectJwk(JWSHeader jwsHeader) { + JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); + + List jwks; + try { + jwks = this.jwkSource.get(jwkSelector, null); + } + catch (KeySourceException ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to select a JWK signing key -> " + ex.getMessage()), ex); + } + + if (jwks.size() > 1) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Found multiple JWK signing keys for algorithm '" + jwsHeader.getAlgorithm().getName() + "'")); + } + + return !jwks.isEmpty() ? jwks.get(0) : null; + } + + private static JWSHeader addKeyIdentifierHeadersIfNecessary(JWSHeader jwsHeader, JWK jwk) { + // Check if headers have already been added + if (StringUtils.hasText(jwsHeader.getKeyID()) && jwsHeader.getX509CertSHA256Thumbprint() != null) { + return jwsHeader; + } + // Check if headers can be added from JWK + if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) { + return jwsHeader; + } + + JWSHeader.Builder headerBuilder = new JWSHeader.Builder(jwsHeader); + if (!StringUtils.hasText(jwsHeader.getKeyID()) && StringUtils.hasText(jwk.getKeyID())) { + headerBuilder.keyID(jwk.getKeyID()); + } + if (jwsHeader.getX509CertSHA256Thumbprint() == null && jwk.getX509CertSHA256Thumbprint() != null) { + headerBuilder.x509CertSHA256Thumbprint(jwk.getX509CertSHA256Thumbprint()); + } + + return headerBuilder.build(); + } + + private static JoseHeader syncKeyIdentifierHeadersIfNecessary(JoseHeader joseHeader, JWSHeader jwsHeader) { + String jwsHeaderX509SHA256Thumbprint = null; + if (jwsHeader.getX509CertSHA256Thumbprint() != null) { + jwsHeaderX509SHA256Thumbprint = jwsHeader.getX509CertSHA256Thumbprint().toString(); + } + if (Objects.equals(joseHeader.getKeyId(), jwsHeader.getKeyID()) + && Objects.equals(joseHeader.getX509SHA256Thumbprint(), jwsHeaderX509SHA256Thumbprint)) { + return joseHeader; + } + + JoseHeader.Builder headerBuilder = JoseHeader.from(joseHeader); + if (!Objects.equals(joseHeader.getKeyId(), jwsHeader.getKeyID())) { + headerBuilder.keyId(jwsHeader.getKeyID()); + } + if (!Objects.equals(joseHeader.getX509SHA256Thumbprint(), jwsHeaderX509SHA256Thumbprint)) { + headerBuilder.x509SHA256Thumbprint(jwsHeaderX509SHA256Thumbprint); + } + + return headerBuilder.build(); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java new file mode 100644 index 00000000000..e256c0c2645 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyType; +import com.nimbusds.jose.jwk.source.ImmutableJWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; + +import org.springframework.http.HttpHeaders; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * An implementation of an + * {@link OAuth2AuthorizationGrantRequestEntityConverter.Customizer} that customizes the + * OAuth 2.0 Access Token Request by adding a signed JSON Web Token (JWS) to be used for + * client authentication at the Authorization Server's Token Endpoint. The private/secret + * key used for signing the JWS is supplied by the {@code com.nimbusds.jose.jwk.JWK} + * resolver provided via the constructor. + * + *

+ * NOTE: This implementation uses the Nimbus JOSE + JWT SDK. + * + * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} + * @author Joe Grandja + * @since 5.5 + * @see OAuth2AuthorizationGrantRequestEntityConverter.Customizer + * @see com.nimbusds.jose.jwk.JWK + * @see JwtCustomizer + * @see 2.2 + * Using JWTs for Client Authentication + * @see 4.2 + * Using Assertions for Client Authentication + * @see Nimbus + * JOSE + JWT SDK + */ +public final class NimbusJwtClientAuthenticationCustomizer + implements OAuth2AuthorizationGrantRequestEntityConverter.Customizer { + + private static final String INVALID_KEY_ERROR_CODE = "invalid_key"; + + private static final String INVALID_ALGORITHM_ERROR_CODE = "invalid_algorithm"; + + private static final String CLIENT_ASSERTION_TYPE_VALUE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; + + private final Function jwkResolver; + + private final Map jwtEncoders = new ConcurrentHashMap<>(); + + private JwtCustomizer jwtCustomizer = (request, headers, claims) -> { + }; + + /** + * Constructs a {@code NimbusJwtClientAuthenticationCustomizer} using the provided + * parameters. + * @param jwkResolver the resolver that provides the {@code com.nimbusds.jose.jwk.JWK} + * associated to the {@link ClientRegistration client} + */ + public NimbusJwtClientAuthenticationCustomizer(Function jwkResolver) { + Assert.notNull(jwkResolver, "jwkResolver cannot be null"); + this.jwkResolver = jwkResolver; + } + + @Override + public void customize(T authorizationGrantRequest, HttpHeaders headers, MultiValueMap parameters) { + Assert.notNull(authorizationGrantRequest, "authorizationGrantRequest cannot be null"); + Assert.notNull(headers, "headers cannot be null"); + Assert.notNull(parameters, "parameters cannot be null"); + + ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); + if (!ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(clientRegistration.getClientAuthenticationMethod()) + && !ClientAuthenticationMethod.CLIENT_SECRET_JWT + .equals(clientRegistration.getClientAuthenticationMethod())) { + return; + } + + JWK jwk = this.jwkResolver.apply(clientRegistration); + if (jwk == null) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_KEY_ERROR_CODE, + "Failed to resolve JWK signing key for client registration '" + + clientRegistration.getRegistrationId() + "'.", + null); + throw new OAuth2AuthorizationException(oauth2Error); + } + + JwsAlgorithm jwsAlgorithm = resolveAlgorithm(jwk); + if (jwsAlgorithm == null) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ALGORITHM_ERROR_CODE, + "Unable to resolve JWS (signing) algorithm from JWK associated to client registration '" + + clientRegistration.getRegistrationId() + "'.", + null); + throw new OAuth2AuthorizationException(oauth2Error); + } + + JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(jwsAlgorithm); + + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(Duration.ofSeconds(30)); + + // @formatter:off + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder() + .issuer(clientRegistration.getClientId()) + .subject(clientRegistration.getClientId()) + .audience(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri())) + .id(UUID.randomUUID().toString()) + .issuedAt(issuedAt) + .expiresAt(expiresAt); + // @formatter:on + + this.jwtCustomizer.customize(authorizationGrantRequest, headersBuilder.headers, claimsBuilder.claims); + + JoseHeader joseHeader = headersBuilder.build(); + JwtClaimsSet jwtClaimsSet = claimsBuilder.build(); + + JwtEncoder jwsEncoder = this.jwtEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), + (clientRegistrationId) -> { + JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); + return new NimbusJwsEncoder(jwkSource); + }); + + Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + + parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE_VALUE); + parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION, jws.getTokenValue()); + } + + /** + * Sets the {@link JwtCustomizer} to be provided the opportunity to customize the + * {@link Jwt} headers and/or claims. + * @param jwtCustomizer the {@link JwtCustomizer} to be provided the opportunity to + * customize the {@link Jwt} headers and/or claims + */ + public void setJwtCustomizer(JwtCustomizer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + + private static JwsAlgorithm resolveAlgorithm(JWK jwk) { + JwsAlgorithm jwsAlgorithm = null; + + if (jwk.getAlgorithm() != null) { + jwsAlgorithm = SignatureAlgorithm.from(jwk.getAlgorithm().getName()); + if (jwsAlgorithm == null) { + jwsAlgorithm = MacAlgorithm.from(jwk.getAlgorithm().getName()); + } + } + + if (jwsAlgorithm == null) { + if (KeyType.RSA.equals(jwk.getKeyType())) { + jwsAlgorithm = SignatureAlgorithm.RS256; + } + else if (KeyType.EC.equals(jwk.getKeyType())) { + jwsAlgorithm = SignatureAlgorithm.ES256; + } + else if (KeyType.OCT.equals(jwk.getKeyType())) { + jwsAlgorithm = MacAlgorithm.HS256; + } + } + + return jwsAlgorithm; + } + + /** + * Implementations of this interface are provided the opportunity to customize the + * {@link Jwt} headers and/or claims. + * + * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} + */ + @FunctionalInterface + interface JwtCustomizer { + + /** + * Customize the {@link Jwt} headers and/or claims. + * @param authorizationGrantRequest the authorization grant request + * @param headers the headers + * @param claims the claims + */ + void customize(T authorizationGrantRequest, Map headers, Map claims); + + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java index feae3d1f378..698ebeec212 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,8 +37,6 @@ */ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { - private final ClientRegistration clientRegistration; - private final OAuth2AuthorizationExchange authorizationExchange; /** @@ -49,21 +47,11 @@ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2Authoriza */ public OAuth2AuthorizationCodeGrantRequest(ClientRegistration clientRegistration, OAuth2AuthorizationExchange authorizationExchange) { - super(AuthorizationGrantType.AUTHORIZATION_CODE); - Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + super(AuthorizationGrantType.AUTHORIZATION_CODE, clientRegistration); Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); - this.clientRegistration = clientRegistration; this.authorizationExchange = authorizationExchange; } - /** - * Returns the {@link ClientRegistration client registration}. - * @return the {@link ClientRegistration} - */ - public ClientRegistration getClientRegistration() { - return this.clientRegistration; - } - /** * Returns the {@link OAuth2AuthorizationExchange authorization exchange}. * @return the {@link OAuth2AuthorizationExchange} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java index 5286efade03..65ec6b28fc3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.util.UriComponentsBuilder; @@ -38,12 +39,15 @@ * * @author Joe Grandja * @since 5.1 - * @see Converter + * @see OAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2AuthorizationCodeGrantRequest * @see RequestEntity */ public class OAuth2AuthorizationCodeGrantRequestEntityConverter - implements Converter> { + implements OAuth2AuthorizationGrantRequestEntityConverter { + + private Customizer customizer = (request, headers, parameters) -> { + }; /** * Returns the {@link RequestEntity} used for the Access Token Request. @@ -54,10 +58,24 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap formParameters = this.buildFormParameters(authorizationCodeGrantRequest); + MultiValueMap parameters = createParameters(authorizationCodeGrantRequest); + this.customizer.customize(authorizationCodeGrantRequest, headers, parameters); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); + } + + /** + * Sets the {@link Customizer} to be provided the opportunity to customize the + * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth + * 2.0 Access Token Request. + * @param customizer the {@link Customizer} to be provided the opportunity to + * customize the OAuth 2.0 Access Token Request + * @since 5.5 + */ + public final void setCustomizer(Customizer customizer) { + Assert.notNull(customizer, "customizer cannot be null"); + this.customizer = customizer; } /** @@ -67,31 +85,31 @@ public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest authorizatio * @return a {@link MultiValueMap} of the form parameters used for the Access Token * Request body */ - private MultiValueMap buildFormParameters( + private MultiValueMap createParameters( OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); OAuth2AuthorizationExchange authorizationExchange = authorizationCodeGrantRequest.getAuthorizationExchange(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); - formParameters.add(OAuth2ParameterNames.GRANT_TYPE, authorizationCodeGrantRequest.getGrantType().getValue()); - formParameters.add(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, authorizationCodeGrantRequest.getGrantType().getValue()); + parameters.add(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); String codeVerifier = authorizationExchange.getAuthorizationRequest() .getAttribute(PkceParameterNames.CODE_VERIFIER); if (redirectUri != null) { - formParameters.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri); + parameters.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri); } if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientRegistration.getClientAuthenticationMethod()) && !ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); } if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod()) || ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } if (codeVerifier != null) { - formParameters.add(PkceParameterNames.CODE_VERIFIER, codeVerifier); + parameters.add(PkceParameterNames.CODE_VERIFIER, codeVerifier); } - return formParameters; + return parameters; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java new file mode 100644 index 00000000000..2f70c79e04f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.RequestEntity; +import org.springframework.util.MultiValueMap; + +/** + * Implementations of this interface are responsible for {@link Converter#convert(Object) + * converting} the provided {@link AbstractOAuth2AuthorizationGrantRequest authorization + * grant credential} to a {@link RequestEntity} representation of an OAuth 2.0 Access + * Token Request. + * + * @author Joe Grandja + * @since 5.5 + * @see Converter + * @see AbstractOAuth2AuthorizationGrantRequest + * @see RequestEntity + * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} + */ +@FunctionalInterface +public interface OAuth2AuthorizationGrantRequestEntityConverter + extends Converter> { + + /** + * Implementations of this interface are provided the opportunity to customize the + * {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. + * + * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} + */ + @FunctionalInterface + interface Customizer { + + /** + * Customize the {@link HttpHeaders headers} and/or {@link MultiValueMap + * parameters} of the OAuth 2.0 Access Token Request. + * @param authorizationGrantRequest the authorization grant request + * @param headers the headers + * @param parameters the parameters + */ + void customize(T authorizationGrantRequest, HttpHeaders headers, MultiValueMap parameters); + + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java index e7fdce4bcb8..6e36a1b9dd0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ */ final class OAuth2AuthorizationGrantRequestEntityUtils { - private static HttpHeaders DEFAULT_TOKEN_REQUEST_HEADERS = getDefaultTokenRequestHeaders(); + private static final HttpHeaders DEFAULT_TOKEN_REQUEST_HEADERS = getDefaultTokenRequestHeaders(); private OAuth2AuthorizationGrantRequestEntityUtils() { } @@ -55,7 +55,7 @@ static HttpHeaders getTokenRequestHeaders(ClientRegistration clientRegistration) private static HttpHeaders getDefaultTokenRequestHeaders() { HttpHeaders headers = new HttpHeaders(); - headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); final MediaType contentType = MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); headers.setContentType(contentType); return headers; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java index f91868a0fe5..b1ab0f1f3b3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,27 +34,15 @@ */ public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { - private final ClientRegistration clientRegistration; - /** * Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided * parameters. * @param clientRegistration the client registration */ public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration) { - super(AuthorizationGrantType.CLIENT_CREDENTIALS); - Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientRegistration); Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()), "clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); - this.clientRegistration = clientRegistration; - } - - /** - * Returns the {@link ClientRegistration client registration}. - * @return the {@link ClientRegistration} - */ - public ClientRegistration getClientRegistration() { - return this.clientRegistration; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java index 73eef647219..ade884e7b94 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -38,12 +39,15 @@ * * @author Joe Grandja * @since 5.1 - * @see Converter + * @see OAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2ClientCredentialsGrantRequest * @see RequestEntity */ public class OAuth2ClientCredentialsGrantRequestEntityConverter - implements Converter> { + implements OAuth2AuthorizationGrantRequestEntityConverter { + + private Customizer customizer = (request, headers, parameters) -> { + }; /** * Returns the {@link RequestEntity} used for the Access Token Request. @@ -54,10 +58,24 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverter public RequestEntity convert(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap formParameters = this.buildFormParameters(clientCredentialsGrantRequest); + MultiValueMap parameters = createParameters(clientCredentialsGrantRequest); + this.customizer.customize(clientCredentialsGrantRequest, headers, parameters); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); + } + + /** + * Sets the {@link Customizer} to be provided the opportunity to customize the + * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth + * 2.0 Access Token Request. + * @param customizer the {@link Customizer} to be provided the opportunity to + * customize the OAuth 2.0 Access Token Request + * @since 5.5 + */ + public final void setCustomizer(Customizer customizer) { + Assert.notNull(customizer, "customizer cannot be null"); + this.customizer = customizer; } /** @@ -67,21 +85,21 @@ public RequestEntity convert(OAuth2ClientCredentialsGrantRequest clientCreden * @return a {@link MultiValueMap} of the form parameters used for the Access Token * Request body */ - private MultiValueMap buildFormParameters( + private MultiValueMap createParameters( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); - formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { - formParameters.add(OAuth2ParameterNames.SCOPE, + parameters.add(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); } if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod()) || ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; + return parameters; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java index cc82b3f47fb..7cddb4c2e12 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,8 +33,6 @@ */ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { - private final ClientRegistration clientRegistration; - private final String username; private final String password; @@ -46,25 +44,15 @@ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrant * @param password the resource owner's password */ public OAuth2PasswordGrantRequest(ClientRegistration clientRegistration, String username, String password) { - super(AuthorizationGrantType.PASSWORD); - Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + super(AuthorizationGrantType.PASSWORD, clientRegistration); Assert.isTrue(AuthorizationGrantType.PASSWORD.equals(clientRegistration.getAuthorizationGrantType()), "clientRegistration.authorizationGrantType must be AuthorizationGrantType.PASSWORD"); Assert.hasText(username, "username cannot be empty"); Assert.hasText(password, "password cannot be empty"); - this.clientRegistration = clientRegistration; this.username = username; this.password = password; } - /** - * Returns the {@link ClientRegistration client registration}. - * @return the {@link ClientRegistration} - */ - public ClientRegistration getClientRegistration() { - return this.clientRegistration; - } - /** * Returns the resource owner's username. * @return the resource owner's username diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java index b0f2e8691ca..aee6d57f6a9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -38,12 +39,15 @@ * * @author Joe Grandja * @since 5.2 - * @see Converter + * @see OAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2PasswordGrantRequest * @see RequestEntity */ public class OAuth2PasswordGrantRequestEntityConverter - implements Converter> { + implements OAuth2AuthorizationGrantRequestEntityConverter { + + private Customizer customizer = (request, headers, parameters) -> { + }; /** * Returns the {@link RequestEntity} used for the Access Token Request. @@ -54,10 +58,24 @@ public class OAuth2PasswordGrantRequestEntityConverter public RequestEntity convert(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap formParameters = buildFormParameters(passwordGrantRequest); + MultiValueMap parameters = createParameters(passwordGrantRequest); + this.customizer.customize(passwordGrantRequest, headers, parameters); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); + } + + /** + * Sets the {@link Customizer} to be provided the opportunity to customize the + * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth + * 2.0 Access Token Request. + * @param customizer the {@link Customizer} to be provided the opportunity to + * customize the OAuth 2.0 Access Token Request + * @since 5.5 + */ + public final void setCustomizer(Customizer customizer) { + Assert.notNull(customizer, "customizer cannot be null"); + this.customizer = customizer; } /** @@ -67,22 +85,22 @@ public RequestEntity convert(OAuth2PasswordGrantRequest passwordGrantRequest) * @return a {@link MultiValueMap} of the form parameters used for the Access Token * Request body */ - private MultiValueMap buildFormParameters(OAuth2PasswordGrantRequest passwordGrantRequest) { + private MultiValueMap createParameters(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); - formParameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue()); - formParameters.add(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername()); - formParameters.add(OAuth2ParameterNames.PASSWORD, passwordGrantRequest.getPassword()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue()); + parameters.add(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername()); + parameters.add(OAuth2ParameterNames.PASSWORD, passwordGrantRequest.getPassword()); if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { - formParameters.add(OAuth2ParameterNames.SCOPE, + parameters.add(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); } if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod()) || ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; + return parameters; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java index 5bb5d127586..d57fad0e97c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,8 +39,6 @@ */ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { - private final ClientRegistration clientRegistration; - private final OAuth2AccessToken accessToken; private final OAuth2RefreshToken refreshToken; @@ -67,25 +65,15 @@ public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAu */ public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken, Set scopes) { - super(AuthorizationGrantType.REFRESH_TOKEN); - Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + super(AuthorizationGrantType.REFRESH_TOKEN, clientRegistration); Assert.notNull(accessToken, "accessToken cannot be null"); Assert.notNull(refreshToken, "refreshToken cannot be null"); - this.clientRegistration = clientRegistration; this.accessToken = accessToken; this.refreshToken = refreshToken; this.scopes = Collections .unmodifiableSet((scopes != null) ? new LinkedHashSet<>(scopes) : Collections.emptySet()); } - /** - * Returns the authorized client's {@link ClientRegistration registration}. - * @return the {@link ClientRegistration} - */ - public ClientRegistration getClientRegistration() { - return this.clientRegistration; - } - /** * Returns the {@link OAuth2AccessToken access token} credential granted. * @return the {@link OAuth2AccessToken} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java index 683fadd4e64..3c8fdae08b0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -38,12 +39,15 @@ * * @author Joe Grandja * @since 5.2 - * @see Converter + * @see OAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2RefreshTokenGrantRequest * @see RequestEntity */ public class OAuth2RefreshTokenGrantRequestEntityConverter - implements Converter> { + implements OAuth2AuthorizationGrantRequestEntityConverter { + + private Customizer customizer = (request, headers, parameters) -> { + }; /** * Returns the {@link RequestEntity} used for the Access Token Request. @@ -54,10 +58,24 @@ public class OAuth2RefreshTokenGrantRequestEntityConverter public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap formParameters = buildFormParameters(refreshTokenGrantRequest); + MultiValueMap parameters = createParameters(refreshTokenGrantRequest); + this.customizer.customize(refreshTokenGrantRequest, headers, parameters); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); + } + + /** + * Sets the {@link Customizer} to be provided the opportunity to customize the + * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth + * 2.0 Access Token Request. + * @param customizer the {@link Customizer} to be provided the opportunity to + * customize the OAuth 2.0 Access Token Request + * @since 5.5 + */ + public final void setCustomizer(Customizer customizer) { + Assert.notNull(customizer, "customizer cannot be null"); + this.customizer = customizer; } /** @@ -67,22 +85,21 @@ public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrant * @return a {@link MultiValueMap} of the form parameters used for the Access Token * Request body */ - private MultiValueMap buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + private MultiValueMap createParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); - formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); - formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, - refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); + parameters.add(OAuth2ParameterNames.REFRESH_TOKEN, refreshTokenGrantRequest.getRefreshToken().getTokenValue()); if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) { - formParameters.add(OAuth2ParameterNames.SCOPE, + parameters.add(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ")); } if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod()) || ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; + return parameters; } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java index 1864dd6f1bf..07cc2621170 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -29,7 +34,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -37,6 +42,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -49,32 +56,25 @@ */ public class DefaultAuthorizationCodeTokenResponseClientTests { - private DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); + private DefaultAuthorizationCodeTokenResponseClient tokenResponseClient; - private ClientRegistration clientRegistration; + private ClientRegistration.Builder clientRegistration; private MockWebServer server; @Before public void setup() throws Exception { + this.tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); // @formatter:off - this.clientRegistration = ClientRegistration - .withRegistrationId("registration-1") + this.clientRegistration = TestClientRegistrations.clientRegistration() .clientId("client-1") .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .redirectUri("https://client.com/callback/client-1") - .scope("read", "write") - .authorizationUri("https://provider.com/oauth2/authorize") .tokenUri(tokenUri) - .userInfoUri("https://provider.com/user") - .userNameAttributeName("id") - .clientName("client-1") - .build(); + .scope("read", "write"); // @formatter:on } @@ -114,11 +114,11 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); Instant expiresAtBefore = Instant.now().plusSeconds(3600); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient - .getTokenResponse(this.authorizationCodeGrantRequest()); + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())); Instant expiresAtAfter = Instant.now().plusSeconds(3600); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); - assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); @@ -136,7 +136,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t } @Test - public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { // @formatter:off String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" @@ -145,13 +145,13 @@ public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeader + "}\n"; // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()); + this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); } @Test - public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { // @formatter:off String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" @@ -160,9 +160,9 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen + "}\n"; // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.from(this.clientRegistration) + ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build(); - this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration)); + this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); String formParameters = recordedRequest.getBody().readUtf8(); @@ -170,6 +170,79 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen assertThat(formParameters).contains("client_secret=secret"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { + NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + jwkResolver); + OAuth2AuthorizationCodeGrantRequestEntityConverter requestEntityConverter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); + requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); + } + @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { // @formatter:off @@ -181,7 +254,8 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); assertThatExceptionOfType(OAuth2AuthorizationException.class) - .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .isThrownBy(() -> this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()))) .withMessageContaining( "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") .withMessageContaining("tokenType cannot be null"); @@ -196,7 +270,8 @@ public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenT // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); assertThatExceptionOfType(OAuth2AuthorizationException.class) - .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .isThrownBy(() -> this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()))) .withMessageContaining( "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") .withMessageContaining("tokenType cannot be null"); @@ -215,7 +290,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient - .getTokenResponse(this.authorizationCodeGrantRequest()); + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @@ -231,16 +306,16 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessToke // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient - .getTokenResponse(this.authorizationCodeGrantRequest()); + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); } @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { String invalidTokenUri = "https://invalid-provider.com/oauth2/token"; - ClientRegistration clientRegistration = this.from(this.clientRegistration).tokenUri(invalidTokenUri).build(); + ClientRegistration clientRegistration = this.clientRegistration.tokenUri(invalidTokenUri).build(); assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy( - () -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration))) + () -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration))) .withMessageContaining( "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @@ -260,7 +335,8 @@ public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationExc // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); assertThatExceptionOfType(OAuth2AuthorizationException.class) - .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .isThrownBy(() -> this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()))) .withMessageContaining( "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @@ -270,7 +346,8 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); assertThatExceptionOfType(OAuth2AuthorizationException.class) - .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .isThrownBy(() -> this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()))) .withMessageContaining("[unauthorized_client]"); } @@ -278,15 +355,12 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); assertThatExceptionOfType(OAuth2AuthorizationException.class) - .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .isThrownBy(() -> this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()))) .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve " + "the OAuth 2.0 Access Token Response"); } - private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() { - return this.authorizationCodeGrantRequest(this.clientRegistration); - } - private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest(ClientRegistration clientRegistration) { OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .clientId(clientRegistration.getClientId()).state("state-1234") @@ -303,22 +377,4 @@ private MockResponse jsonResponse(String json) { return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } - private ClientRegistration.Builder from(ClientRegistration registration) { - // @formatter:off - return ClientRegistration.withRegistrationId(registration.getRegistrationId()) - .clientId(registration.getClientId()) - .clientSecret(registration.getClientSecret()) - .clientAuthenticationMethod(registration.getClientAuthenticationMethod()) - .authorizationGrantType(registration.getAuthorizationGrantType()) - .redirectUri(registration.getRedirectUri()) - .scope(registration.getScopes()) - .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) - .tokenUri(registration.getProviderDetails().getTokenUri()) - .userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri()) - .userNameAttributeName( - registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) - .clientName(registration.getClientName()); - // @formatter:on - } - } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java index fee6942fa56..2ce24fe9c9b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -29,11 +34,13 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -46,26 +53,24 @@ */ public class DefaultClientCredentialsTokenResponseClientTests { - private DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + private DefaultClientCredentialsTokenResponseClient tokenResponseClient; - private ClientRegistration clientRegistration; + private ClientRegistration.Builder clientRegistration; private MockWebServer server; @Before public void setup() throws Exception { + this.tokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); // @formatter:off - this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") + this.clientRegistration = TestClientRegistrations.clientCredentials() .clientId("client-1") .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .scope("read", "write") .tokenUri(tokenUri) - .build(); + .scope("read", "write"); // @formatter:on } @@ -110,13 +115,13 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); Instant expiresAtBefore = Instant.now().plusSeconds(3600); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(clientCredentialsGrantRequest); Instant expiresAtAfter = Instant.now().plusSeconds(3600); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); - assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); @@ -133,7 +138,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t } @Test - public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { // @formatter:off String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" @@ -143,14 +148,14 @@ public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeader // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); } @Test - public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { // @formatter:off String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" @@ -159,7 +164,7 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen + "}\n"; // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.from(this.clientRegistration) + ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build(); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); @@ -171,6 +176,83 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen assertThat(formParameters).contains("client_secret=secret"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { + NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + jwkResolver); + OAuth2ClientCredentialsGrantRequestEntityConverter requestEntityConverter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); + requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); + } + @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { // @formatter:off @@ -182,7 +264,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .withMessageContaining( @@ -195,7 +277,7 @@ public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenT String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .withMessageContaining( @@ -215,7 +297,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(clientCredentialsGrantRequest); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); @@ -232,7 +314,7 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessToke // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(clientCredentialsGrantRequest); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); @@ -241,7 +323,7 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessToke @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { String invalidTokenUri = "https://invalid-provider.com/oauth2/token"; - ClientRegistration clientRegistration = this.from(this.clientRegistration).tokenUri(invalidTokenUri).build(); + ClientRegistration clientRegistration = this.clientRegistration.tokenUri(invalidTokenUri).build(); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); assertThatExceptionOfType(OAuth2AuthorizationException.class) @@ -264,7 +346,7 @@ public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationExc // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .withMessageContaining( @@ -280,7 +362,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti // @formatter:on this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .withMessageContaining("[unauthorized_client]"); @@ -290,7 +372,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - this.clientRegistration); + this.clientRegistration.build()); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .withMessageContaining( @@ -301,16 +383,4 @@ private MockResponse jsonResponse(String json) { return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } - private ClientRegistration.Builder from(ClientRegistration registration) { - // @formatter:off - return ClientRegistration.withRegistrationId(registration.getRegistrationId()) - .clientId(registration.getClientId()) - .clientSecret(registration.getClientSecret()) - .clientAuthenticationMethod(registration.getClientAuthenticationMethod()) - .authorizationGrantType(registration.getAuthorizationGrantType()) - .scope(registration.getScopes()) - .tokenUri(registration.getProviderDetails().getTokenUri()); - // @formatter:on - } - } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java index 66d098880bf..8fb756a8112 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -30,11 +35,12 @@ import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -47,9 +53,9 @@ */ public class DefaultPasswordTokenResponseClientTests { - private DefaultPasswordTokenResponseClient tokenResponseClient = new DefaultPasswordTokenResponseClient(); + private DefaultPasswordTokenResponseClient tokenResponseClient; - private ClientRegistration.Builder clientRegistrationBuilder; + private ClientRegistration.Builder clientRegistration; private String username = "user1"; @@ -59,11 +65,15 @@ public class DefaultPasswordTokenResponseClientTests { @Before public void setup() throws Exception { + this.tokenResponseClient = new DefaultPasswordTokenResponseClient(); this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() - .authorizationGrantType(AuthorizationGrantType.PASSWORD).scope("read", "write").tokenUri(tokenUri); + // @formatter:off + this.clientRegistration = TestClientRegistrations.password() + .scope("read", "write") + .tokenUri(tokenUri); + // @formatter:on } @After @@ -97,14 +107,14 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); Instant expiresAtBefore = Instant.now().plusSeconds(3600); - ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); + ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, this.username, this.password); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest); Instant expiresAtAfter = Instant.now().plusSeconds(3600); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); - assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); @@ -121,7 +131,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t } @Test - public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { // @formatter:off String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" @@ -130,7 +140,7 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen + "}\n"; // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.clientRegistrationBuilder + ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build(); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, this.username, this.password); @@ -142,6 +152,83 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen assertThat(formParameters).contains("client_secret=client-secret"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); + this.tokenResponseClient.getTokenResponse(passwordGrantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); + this.tokenResponseClient.getTokenResponse(passwordGrantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { + NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + jwkResolver); + OAuth2PasswordGrantRequestEntityConverter requestEntityConverter = new OAuth2PasswordGrantRequestEntityConverter(); + requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); + } + @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { // @formatter:off @@ -153,7 +240,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - this.clientRegistrationBuilder.build(), this.username, this.password); + this.clientRegistration.build(), this.username, this.password); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) .withMessageContaining( @@ -173,7 +260,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - this.clientRegistrationBuilder.build(), this.username, this.password); + this.clientRegistration.build(), this.username, this.password); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); @@ -186,7 +273,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - this.clientRegistrationBuilder.build(), this.username, this.password); + this.clientRegistration.build(), this.username, this.password); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) .withMessageContaining("[unauthorized_client]"); @@ -196,7 +283,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - this.clientRegistrationBuilder.build(), this.username, this.password); + this.clientRegistration.build(), this.username, this.password); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) .withMessageContaining("[invalid_token_response] An error occurred while attempting to " diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java index c997e94ea9d..d5f1413d453 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,14 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -38,6 +43,8 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -50,9 +57,9 @@ */ public class DefaultRefreshTokenTokenResponseClientTests { - private DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private DefaultRefreshTokenTokenResponseClient tokenResponseClient; - private ClientRegistration.Builder clientRegistrationBuilder; + private ClientRegistration.Builder clientRegistration; private OAuth2AccessToken accessToken; @@ -62,10 +69,11 @@ public class DefaultRefreshTokenTokenResponseClientTests { @Before public void setup() throws Exception { + this.tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); + this.clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); } @@ -102,13 +110,13 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); Instant expiresAtBefore = Instant.now().plusSeconds(3600); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + this.clientRegistration.build(), this.accessToken, this.refreshToken); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(refreshTokenGrantRequest); Instant expiresAtAfter = Instant.now().plusSeconds(3600); RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); - assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); @@ -124,11 +132,16 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t } @Test - public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.clientRegistrationBuilder + ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build(); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -140,6 +153,83 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen assertThat(formParameters).contains("client_secret=client-secret"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication customizer + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationCustomizer(jwkResolver); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters) + .contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"); + assertThat(formParameters).contains("client_assertion="); + } + + private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { + NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + jwkResolver); + OAuth2RefreshTokenGrantRequestEntityConverter requestEntityConverter = new OAuth2RefreshTokenGrantRequestEntityConverter(); + requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); + } + @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { // @formatter:off @@ -151,7 +241,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + this.clientRegistration.build(), this.accessToken, this.refreshToken); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) .withMessageContaining("[invalid_token_response] An error occurred while attempting to " @@ -171,8 +261,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, - Collections.singleton("read")); + this.clientRegistration.build(), this.accessToken, this.refreshToken, Collections.singleton("read")); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(refreshTokenGrantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); @@ -186,7 +275,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + this.clientRegistration.build(), this.accessToken, this.refreshToken); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) .withMessageContaining("[unauthorized_client]"); @@ -196,7 +285,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + this.clientRegistration.build(), this.accessToken, this.refreshToken); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) .withMessageContaining("[invalid_token_response] An error occurred while attempting to " diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java new file mode 100644 index 00000000000..0556bbb473e --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Test; + +import org.springframework.security.oauth2.jose.JwaAlgorithm; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * Tests for {@link JoseHeader}. + * + * @author Joe Grandja + */ +public class JoseHeaderTests { + + @Test + public void withAlgorithmWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.withAlgorithm(null)) + .isInstanceOf(IllegalArgumentException.class).withMessage("jwaAlgorithm cannot be null"); + } + + @Test + public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { + JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); + + // @formatter:off + JoseHeader joseHeader = JoseHeader.withAlgorithm(expectedJoseHeader.getAlgorithm()) + .jwkSetUri(expectedJoseHeader.getJwkSetUri().toExternalForm()) + .jwk(expectedJoseHeader.getJwk()) + .keyId(expectedJoseHeader.getKeyId()) + .x509Uri(expectedJoseHeader.getX509Uri().toExternalForm()) + .x509CertificateChain(expectedJoseHeader.getX509CertificateChain()) + .x509SHA1Thumbprint(expectedJoseHeader.getX509SHA1Thumbprint()) + .x509SHA256Thumbprint(expectedJoseHeader.getX509SHA256Thumbprint()) + .type(expectedJoseHeader.getType()) + .contentType(expectedJoseHeader.getContentType()) + .headers((headers) -> headers.put("custom-header-name", "custom-header-value")) + .build(); + // @formatter:on + + assertThat(joseHeader.getAlgorithm()).isEqualTo(expectedJoseHeader.getAlgorithm()); + assertThat(joseHeader.getJwkSetUri()).isEqualTo(expectedJoseHeader.getJwkSetUri()); + assertThat(joseHeader.getJwk()).isEqualTo(expectedJoseHeader.getJwk()); + assertThat(joseHeader.getKeyId()).isEqualTo(expectedJoseHeader.getKeyId()); + assertThat(joseHeader.getX509Uri()).isEqualTo(expectedJoseHeader.getX509Uri()); + assertThat(joseHeader.getX509CertificateChain()).isEqualTo(expectedJoseHeader.getX509CertificateChain()); + assertThat(joseHeader.getX509SHA1Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA1Thumbprint()); + assertThat(joseHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA256Thumbprint()); + assertThat(joseHeader.getType()).isEqualTo(expectedJoseHeader.getType()); + assertThat(joseHeader.getContentType()).isEqualTo(expectedJoseHeader.getContentType()); + assertThat(joseHeader.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); + } + + @Test + public void fromWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.from(null)) + .isInstanceOf(IllegalArgumentException.class).withMessage("headers cannot be null"); + } + + @Test + public void fromWhenHeadersProvidedThenCopied() { + JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); + JoseHeader joseHeader = JoseHeader.from(expectedJoseHeader).build(); + assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); + } + + @Test + public void headerWhenNameNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header(null, "value")) + .withMessage("name cannot be empty"); + } + + @Test + public void headerWhenValueNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header("name", null)) + .withMessage("value cannot be null"); + } + + @Test + public void getHeaderWhenNullThenThrowIllegalArgumentException() { + JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); + + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> joseHeader.getHeader(null)) + .isInstanceOf(IllegalArgumentException.class).withMessage("name cannot be empty"); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java new file mode 100644 index 00000000000..2b10762c66f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * Tests for {@link JwtClaimsSet}. + * + * @author Joe Grandja + */ +public class JwtClaimsSetTests { + + @Test + public void buildWhenClaimsEmptyThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwtClaimsSet.builder().build()) + .isInstanceOf(IllegalArgumentException.class).withMessage("claims cannot be empty"); + } + + @Test + public void buildWhenAllClaimsProvidedThenAllClaimsAreSet() { + JwtClaimsSet expectedJwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + // @formatter:off + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder() + .issuer(expectedJwtClaimsSet.getIssuer().toExternalForm()) + .subject(expectedJwtClaimsSet.getSubject()) + .audience(expectedJwtClaimsSet.getAudience()) + .issuedAt(expectedJwtClaimsSet.getIssuedAt()) + .notBefore(expectedJwtClaimsSet.getNotBefore()) + .expiresAt(expectedJwtClaimsSet.getExpiresAt()) + .id(expectedJwtClaimsSet.getId()) + .claims((claims) -> claims.put("custom-claim-name", "custom-claim-value")) + .build(); + // @formatter:on + + assertThat(jwtClaimsSet.getIssuer()).isEqualTo(expectedJwtClaimsSet.getIssuer()); + assertThat(jwtClaimsSet.getSubject()).isEqualTo(expectedJwtClaimsSet.getSubject()); + assertThat(jwtClaimsSet.getAudience()).isEqualTo(expectedJwtClaimsSet.getAudience()); + assertThat(jwtClaimsSet.getIssuedAt()).isEqualTo(expectedJwtClaimsSet.getIssuedAt()); + assertThat(jwtClaimsSet.getNotBefore()).isEqualTo(expectedJwtClaimsSet.getNotBefore()); + assertThat(jwtClaimsSet.getExpiresAt()).isEqualTo(expectedJwtClaimsSet.getExpiresAt()); + assertThat(jwtClaimsSet.getId()).isEqualTo(expectedJwtClaimsSet.getId()); + assertThat(jwtClaimsSet.getClaim("custom-claim-name")).isEqualTo("custom-claim-value"); + assertThat(jwtClaimsSet.getClaims()).isEqualTo(expectedJwtClaimsSet.getClaims()); + } + + @Test + public void fromWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwtClaimsSet.from(null)) + .isInstanceOf(IllegalArgumentException.class).withMessage("claims cannot be null"); + } + + @Test + public void fromWhenClaimsProvidedThenCopied() { + JwtClaimsSet expectedJwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.from(expectedJwtClaimsSet).build(); + assertThat(jwtClaimsSet.getClaims()).isEqualTo(expectedJwtClaimsSet.getClaims()); + } + + @Test + public void claimWhenNameNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JwtClaimsSet.builder().claim(null, "value")).withMessage("name cannot be empty"); + } + + @Test + public void claimWhenValueNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JwtClaimsSet.builder().claim("name", null)).withMessage("value cannot be null"); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java new file mode 100644 index 00000000000..2d29a370ea0 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java @@ -0,0 +1,347 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.security.interfaces.ECPrivateKey; +import java.security.interfaces.ECPublicKey; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jose.util.Base64URL; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * Tests for {@link NimbusJwsEncoder}. + * + * @author Joe Grandja + */ +public class NimbusJwsEncoderTests { + + private List jwkList; + + private JWKSource jwkSource; + + private NimbusJwsEncoder jwsEncoder; + + @Before + public void setUp() { + this.jwkList = new ArrayList<>(); + this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); + this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + } + + @Test + public void constructorWhenJwkSourceNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwsEncoder(null)) + .withMessage("jwkSource cannot be null"); + } + + @Test + public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(null, jwtClaimsSet)) + .withMessage("headers cannot be null"); + } + + @Test + public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + + assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(joseHeader, null)) + .withMessage("claims cannot be null"); + } + + @Test + public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception { + this.jwkSource = mock(JWKSource.class); + this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + given(this.jwkSource.get(any(), any())).willThrow(new KeySourceException("key source error")); + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) + .withMessageContaining("Failed to select a JWK signing key -> key source error"); + } + + @Test + public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + this.jwkList.add(rsaJwk); + this.jwkList.add(rsaJwk); + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) + .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); + } + + @Test + public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) + .withMessageContaining("Failed to select a JWK signing key"); + } + + @Test + public void encodeWhenJwkSelectWithProvidedKidThenSelected() { + // @formatter:off + RSAKey rsaJwk1 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .keyID("rsa-jwk-1") + .build(); + this.jwkList.add(rsaJwk1); + RSAKey rsaJwk2 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .keyID("rsa-jwk-2") + .build(); + this.jwkList.add(rsaJwk2); + // @formatter:on + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk2.getKeyID()); + } + + @Test + public void encodeWhenJwkSelectWithProvidedX5TS256ThenSelected() { + // @formatter:off + RSAKey rsaJwk1 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .x509CertSHA256Thumbprint(new Base64URL("x509CertSHA256Thumbprint-1")) + .keyID(null) + .build(); + this.jwkList.add(rsaJwk1); + RSAKey rsaJwk2 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .x509CertSHA256Thumbprint(new Base64URL("x509CertSHA256Thumbprint-2")) + .keyID(null) + .build(); + this.jwkList.add(rsaJwk2); + // @formatter:on + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256) + .x509SHA256Thumbprint(rsaJwk1.getX509CertSHA256Thumbprint().toString()).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256)) + .isEqualTo(rsaJwk1.getX509CertSHA256Thumbprint().toString()); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isNull(); + } + + @Test + public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exception { + // @formatter:off + RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .keyUse(KeyUse.ENCRYPTION) + .build(); + // @formatter:on + + this.jwkSource = mock(JWKSource.class); + this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)).withMessageContaining( + "Failed to create a JWS Signer -> The JWK use must be sig (signature) or unspecified"); + } + + @Test + public void encodeWhenSuccessThenDecodes() throws Exception { + // @formatter:off + RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .keyID("rsa-jwk-1") + .x509CertSHA256Thumbprint(new Base64URL("x509CertSHA256Thumbprint-1")) + .build(); + this.jwkList.add(rsaJwk); + // @formatter:on + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(joseHeader.getAlgorithm()); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JWK)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5U)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5C)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256)) + .isEqualTo(rsaJwk.getX509CertSHA256Thumbprint().toString()); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.TYP)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.CTY)).isNull(); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.CRIT)).isNull(); + + assertThat(encodedJws.getIssuer()).isEqualTo(jwtClaimsSet.getIssuer()); + assertThat(encodedJws.getSubject()).isEqualTo(jwtClaimsSet.getSubject()); + assertThat(encodedJws.getAudience()).isEqualTo(jwtClaimsSet.getAudience()); + assertThat(encodedJws.getExpiresAt()).isEqualTo(jwtClaimsSet.getExpiresAt()); + assertThat(encodedJws.getNotBefore()).isEqualTo(jwtClaimsSet.getNotBefore()); + assertThat(encodedJws.getIssuedAt()).isEqualTo(jwtClaimsSet.getIssuedAt()); + assertThat(encodedJws.getId()).isEqualTo(jwtClaimsSet.getId()); + assertThat(encodedJws.getClaim("custom-claim-name")).isEqualTo("custom-claim-value"); + + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws.getTokenValue()); + } + + @Test + public void encodeWhenKeysRotatedThenNewKeyUsed() throws Exception { + TestJWKSource jwkSource = new TestJWKSource(); + JWKSource jwkSourceDelegate = spy(new JWKSource() { + @Override + public List get(JWKSelector jwkSelector, SecurityContext context) { + return jwkSource.get(jwkSelector, context); + } + }); + NimbusJwsEncoder jwsEncoder = new NimbusJwsEncoder(jwkSourceDelegate); + + JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor(); + willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any()); + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + + JWK jwk1 = jwkListResultCaptor.getResult().get(0); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk1).toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws.getTokenValue()); + + jwkSource.rotate(); // Simulate key rotation + + encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + + JWK jwk2 = jwkListResultCaptor.getResult().get(0); + jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk2).toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws.getTokenValue()); + + assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID()); + } + + private static final class JwkListResultCaptor implements Answer> { + + private List result; + + private List getResult() { + return this.result; + } + + @SuppressWarnings("unchecked") + @Override + public List answer(InvocationOnMock invocationOnMock) throws Throwable { + this.result = (List) invocationOnMock.callRealMethod(); + return this.result; + } + + } + + private static final class TestJWKSource implements JWKSource { + + private int keyId = 1000; + + private JWKSet jwkSet; + + private TestJWKSource() { + init(); + } + + @Override + public List get(JWKSelector jwkSelector, SecurityContext context) { + return jwkSelector.select(this.jwkSet); + } + + private void init() { + // @formatter:off + RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .keyID("rsa-jwk-" + this.keyId++) + .build(); + ECKey ecJwk = TestJwks.jwk((ECPublicKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPublic(), (ECPrivateKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPrivate()) + .keyID("ec-jwk-" + this.keyId++) + .build(); + OctetSequenceKey secretJwk = TestJwks.jwk(TestKeys.DEFAULT_SECRET_KEY) + .keyID("secret-jwk-" + this.keyId++) + .build(); + // @formatter:on + this.jwkSet = new JWKSet(Arrays.asList(rsaJwk, ecJwk, secretJwk)); + } + + private void rotate() { + init(); + } + + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java new file mode 100644 index 00000000000..d09b12a5430 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Collections; +import java.util.function.Function; + +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link NimbusJwtClientAuthenticationCustomizer}. + * + * @author Joe Grandja + */ +public class NimbusJwtClientAuthenticationCustomizerTests { + + private Function jwkResolver; + + private NimbusJwtClientAuthenticationCustomizer customizer; + + @Before + public void setup() { + this.jwkResolver = mock(Function.class); + this.customizer = new NimbusJwtClientAuthenticationCustomizer<>(this.jwkResolver); + } + + @Test + public void constructorWhenJwkResolverNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwtClientAuthenticationCustomizer<>(null)) + .withMessage("jwkResolver cannot be null"); + } + + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.customizer.setJwtCustomizer(null)) + .withMessage("jwtCustomizer cannot be null"); + } + + @Test + public void customizeWhenAuthorizationGrantRequestNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.customizer.customize(null, new HttpHeaders(), new LinkedMultiValueMap<>())) + .withMessage("authorizationGrantRequest cannot be null"); + } + + @Test + public void customizeWhenHeadersNullThenThrowIllegalArgumentException() { + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + TestClientRegistrations.clientCredentials().build()); + assertThatIllegalArgumentException().isThrownBy( + () -> this.customizer.customize(clientCredentialsGrantRequest, null, new LinkedMultiValueMap<>())) + .withMessage("headers cannot be null"); + } + + @Test + public void customizeWhenParametersNullThenThrowIllegalArgumentException() { + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + TestClientRegistrations.clientCredentials().build()); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), null)) + .withMessage("parameters cannot be null"); + } + + @Test + public void customizeWhenOtherClientAuthenticationMethodThenNotCustomized() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .build(); + // @formatter:on + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), new LinkedMultiValueMap<>()); + verifyNoInteractions(this.jwkResolver); + } + + @Test + public void customizeWhenJwkNotResolvedThenThrowOAuth2AuthorizationException() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), + new LinkedMultiValueMap<>())) + .withMessage("[invalid_key] Failed to resolve JWK signing key for client registration '" + + clientRegistration.getRegistrationId() + "'."); + } + + @Test + public void customizeWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkResolver.apply(any())).willReturn(rsaJwk); + + // Add custom claim + this.customizer.setJwtCustomizer( + (authorizationGrantRequest, headers, claims) -> claims.put("custom-claim", "custom-value")); + + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + + this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), parameters); + + assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE)) + .isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + assertThat(encodedJws).isNotNull(); + + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk.toRSAPublicKey()).build(); + Jwt jws = jwtDecoder.decode(encodedJws); + + assertThat(jws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(SignatureAlgorithm.RS256.getName()); + assertThat(jws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); + assertThat(jws.getClaim(JwtClaimNames.ISS)).isEqualTo(clientRegistration.getClientId()); + assertThat(jws.getSubject()).isEqualTo(clientRegistration.getClientId()); + assertThat(jws.getAudience()) + .isEqualTo(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri())); + assertThat(jws.getId()).isNotNull(); + assertThat(jws.getIssuedAt()).isNotNull(); + assertThat(jws.getExpiresAt()).isNotNull(); + assertThat(jws.getClaim("custom-claim")).isEqualTo("custom-value"); + } + + @Test + public void customizeWhenClientSecretJwtClientAuthenticationMethodThenCustomized() { + OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK; + given(this.jwkResolver.apply(any())).willReturn(secretJwk); + + // Add custom claim + this.customizer.setJwtCustomizer( + (authorizationGrantRequest, headers, claims) -> claims.put("custom-claim", "custom-value")); + + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + // @formatter:on + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + + this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), parameters); + + assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE)) + .isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + assertThat(encodedJws).isNotNull(); + + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withSecretKey(secretJwk.toSecretKey()).build(); + Jwt jws = jwtDecoder.decode(encodedJws); + + assertThat(jws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(MacAlgorithm.HS256.getName()); + assertThat(jws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(secretJwk.getKeyID()); + assertThat(jws.getClaim(JwtClaimNames.ISS)).isEqualTo(clientRegistration.getClientId()); + assertThat(jws.getSubject()).isEqualTo(clientRegistration.getClientId()); + assertThat(jws.getAudience()) + .isEqualTo(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri())); + assertThat(jws.getId()).isNotNull(); + assertThat(jws.getIssuedAt()).isNotNull(); + assertThat(jws.getExpiresAt()).isNotNull(); + assertThat(jws.getClaim("custom-claim")).isEqualTo("custom-value"); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java index fb46cfccdee..91a7fe3f75a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,10 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; +import org.junit.Before; import org.junit.Test; import org.springframework.http.HttpHeaders; @@ -28,16 +27,23 @@ import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2AuthorizationCodeGrantRequestEntityConverter}. @@ -46,49 +52,40 @@ */ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests { - private OAuth2AuthorizationCodeGrantRequestEntityConverter converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); + private OAuth2AuthorizationCodeGrantRequestEntityConverter converter; - // @formatter:off - private ClientRegistration.Builder clientRegistrationBuilder = ClientRegistration - .withRegistrationId("registration-1") - .clientId("client-1") - .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri("https://client.com/callback/client-1") - .scope("read", "write") - .authorizationUri("https://provider.com/oauth2/authorize") - .tokenUri("https://provider.com/oauth2/token") - .userInfoUri("https://provider.com/user") - .userNameAttributeName("id") - .clientName("client-1"); - // @formatter:on + @Before + public void setup() { + this.converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); + } - // @formatter:off - private OAuth2AuthorizationRequest.Builder authorizationRequestBuilder = OAuth2AuthorizationRequest - .authorizationCode() - .clientId("client-1") - .state("state-1234") - .authorizationUri("https://provider.com/oauth2/authorize") - .redirectUri("https://client.com/callback/client-1") - .scopes(new HashSet(Arrays.asList("read", "write"))); - // @formatter:on + @Test + public void setCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) + .withMessage("customizer cannot be null"); + } - // @formatter:off - private OAuth2AuthorizationResponse.Builder authorizationResponseBuilder = OAuth2AuthorizationResponse - .success("code-1234") - .state("state-1234") - .redirectUri("https://client.com/callback/client-1"); - // @formatter:on + @Test + public void convertWhenCustomizerSetThenCalled() { + OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( + OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); + this.converter.setCustomizer(customizer); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( + clientRegistration, authorizationExchange); + this.converter.convert(authorizationCodeGrantRequest); + verify(customizer).customize(any(OAuth2AuthorizationCodeGrantRequest.class), any(HttpHeaders.class), + any(MultiValueMap.class)); + } @SuppressWarnings("unchecked") @Test public void convertWhenGrantRequestValidThenConverts() { - ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); - OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.build(); - OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBuilder.build(); - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, - authorizationResponse); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AuthorizationRequest authorizationRequest = authorizationExchange.getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse(); OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( clientRegistration, authorizationExchange); RequestEntity requestEntity = this.converter.convert(authorizationCodeGrantRequest); @@ -96,32 +93,32 @@ public void convertWhenGrantRequestValidThenConverts() { assertThat(requestEntity.getUrl().toASCIIString()) .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); - assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); assertThat(headers.getContentType()) .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo("code-1234"); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo(authorizationResponse.getCode()); assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isNull(); assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)) - .isEqualTo(clientRegistration.getRedirectUri()); + .isEqualTo(authorizationRequest.getRedirectUri()); } @SuppressWarnings("unchecked") @Test public void convertWhenPkceGrantRequestValidThenConverts() { - ClientRegistration clientRegistration = this.clientRegistrationBuilder.clientAuthenticationMethod(null) - .clientSecret(null).build(); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .clientAuthenticationMethod(null).clientSecret(null).build(); Map attributes = new HashMap<>(); attributes.put(PkceParameterNames.CODE_VERIFIER, "code-verifier-1234"); Map additionalParameters = new HashMap<>(); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, "code-challenge-1234"); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.attributes(attributes) - .additionalParameters(additionalParameters).build(); - OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBuilder.build(); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .attributes(attributes).additionalParameters(additionalParameters).build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().build(); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( @@ -131,18 +128,20 @@ public void convertWhenPkceGrantRequestValidThenConverts() { assertThat(requestEntity.getUrl().toASCIIString()) .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); - assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); assertThat(headers.getContentType()) .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isNull(); MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo("code-1234"); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo(authorizationResponse.getCode()); assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)) - .isEqualTo(clientRegistration.getRedirectUri()); - assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isEqualTo("client-1"); - assertThat(formParameters.getFirst(PkceParameterNames.CODE_VERIFIER)).isEqualTo("code-verifier-1234"); + .isEqualTo(authorizationRequest.getRedirectUri()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)) + .isEqualTo(authorizationRequest.getClientId()); + assertThat(formParameters.getFirst(PkceParameterNames.CODE_VERIFIER)) + .isEqualTo(authorizationRequest.getAttribute(PkceParameterNames.CODE_VERIFIER)); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java index 79825c2b99e..03c3476aea9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,12 +24,16 @@ import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2ClientCredentialsGrantRequestEntityConverter}. @@ -38,42 +42,51 @@ */ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests { - private OAuth2ClientCredentialsGrantRequestEntityConverter converter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); - - private OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest; + private OAuth2ClientCredentialsGrantRequestEntityConverter converter; @Before public void setup() { - // @formatter:off - ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") - .clientId("client-1") - .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .scope("read", "write") - .tokenUri("https://provider.com/oauth2/token") - .build(); - // @formatter:on - this.clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + this.converter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); + } + + @Test + public void setCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) + .withMessage("customizer cannot be null"); + } + + @Test + public void convertWhenCustomizerSetThenCalled() { + OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( + OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); + this.converter.setCustomizer(customizer); + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + this.converter.convert(clientCredentialsGrantRequest); + verify(customizer).customize(any(OAuth2ClientCredentialsGrantRequest.class), any(HttpHeaders.class), + any(MultiValueMap.class)); } @SuppressWarnings("unchecked") @Test public void convertWhenGrantRequestValidThenConverts() { - RequestEntity requestEntity = this.converter.convert(this.clientCredentialsGrantRequest); - ClientRegistration clientRegistration = this.clientCredentialsGrantRequest.getClientRegistration(); + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + RequestEntity requestEntity = this.converter.convert(clientCredentialsGrantRequest); assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); assertThat(requestEntity.getUrl().toASCIIString()) .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); - assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); assertThat(headers.getContentType()) .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) .isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); - assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write"); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes()); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java index 7e85dcc4978..84b5c5aa661 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,10 @@ import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2PasswordGrantRequestEntityConverter}. @@ -38,31 +42,44 @@ */ public class OAuth2PasswordGrantRequestEntityConverterTests { - private OAuth2PasswordGrantRequestEntityConverter converter = new OAuth2PasswordGrantRequestEntityConverter(); - - private OAuth2PasswordGrantRequest passwordGrantRequest; + private OAuth2PasswordGrantRequestEntityConverter converter; @Before public void setup() { - // @formatter:off - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .scope("read", "write") - .build(); - // @formatter:on - this.passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", "password"); + this.converter = new OAuth2PasswordGrantRequestEntityConverter(); + } + + @Test + public void setCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) + .withMessage("customizer cannot be null"); + } + + @Test + public void convertWhenCustomizerSetThenCalled() { + OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( + OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); + this.converter.setCustomizer(customizer); + ClientRegistration clientRegistration = TestClientRegistrations.password().build(); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", + "password"); + this.converter.convert(passwordGrantRequest); + verify(customizer).customize(any(OAuth2PasswordGrantRequest.class), any(HttpHeaders.class), + any(MultiValueMap.class)); } @SuppressWarnings("unchecked") @Test public void convertWhenGrantRequestValidThenConverts() { - RequestEntity requestEntity = this.converter.convert(this.passwordGrantRequest); - ClientRegistration clientRegistration = this.passwordGrantRequest.getClientRegistration(); + ClientRegistration clientRegistration = TestClientRegistrations.password().build(); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", + "password"); + RequestEntity requestEntity = this.converter.convert(passwordGrantRequest); assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); assertThat(requestEntity.getUrl().toASCIIString()) .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); - assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); assertThat(headers.getContentType()) .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); @@ -71,7 +88,7 @@ public void convertWhenGrantRequestValidThenConverts() { .isEqualTo(AuthorizationGrantType.PASSWORD.getValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.USERNAME)).isEqualTo("user1"); assertThat(formParameters.getFirst(OAuth2ParameterNames.PASSWORD)).isEqualTo("password"); - assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write"); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes()); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java index 60c53bdeda5..104fe3d9696 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; @@ -35,6 +36,10 @@ import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2RefreshTokenGrantRequestEntityConverter}. @@ -43,28 +48,48 @@ */ public class OAuth2RefreshTokenGrantRequestEntityConverterTests { - private OAuth2RefreshTokenGrantRequestEntityConverter converter = new OAuth2RefreshTokenGrantRequestEntityConverter(); - - private OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest; + private OAuth2RefreshTokenGrantRequestEntityConverter converter; @Before public void setup() { - this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - TestClientRegistrations.clientRegistration().build(), TestOAuth2AccessTokens.scopes("read", "write"), - TestOAuth2RefreshTokens.refreshToken(), Collections.singleton("read")); + this.converter = new OAuth2RefreshTokenGrantRequestEntityConverter(); + } + + @Test + public void setCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) + .withMessage("customizer cannot be null"); + } + + @Test + public void convertWhenCustomizerSetThenCalled() { + OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( + OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); + this.converter.setCustomizer(customizer); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + accessToken, refreshToken); + this.converter.convert(refreshTokenGrantRequest); + verify(customizer).customize(any(OAuth2RefreshTokenGrantRequest.class), any(HttpHeaders.class), + any(MultiValueMap.class)); } @SuppressWarnings("unchecked") @Test public void convertWhenGrantRequestValidThenConverts() { - RequestEntity requestEntity = this.converter.convert(this.refreshTokenGrantRequest); - ClientRegistration clientRegistration = this.refreshTokenGrantRequest.getClientRegistration(); - OAuth2RefreshToken refreshToken = this.refreshTokenGrantRequest.getRefreshToken(); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + accessToken, refreshToken, Collections.singleton("read")); + RequestEntity requestEntity = this.converter.convert(refreshTokenGrantRequest); assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); assertThat(requestEntity.getUrl().toASCIIString()) .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); - assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); assertThat(headers.getContentType()) .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java new file mode 100644 index 00000000000..9331fd40b5c --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * @author Joe Grandja + */ +final class TestJoseHeaders { + + private TestJoseHeaders() { + } + + static JoseHeader.Builder joseHeader() { + return joseHeader(SignatureAlgorithm.RS256); + } + + static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) { + // @formatter:off + return JoseHeader.withAlgorithm(signatureAlgorithm) + .jwkSetUri("https://provider.com/oauth2/jwks") + .jwk(rsaJwk()) + .keyId("keyId") + .x509Uri("https://provider.com/oauth2/x509") + .x509CertificateChain(Arrays.asList("x509Cert1", "x509Cert2")) + .x509SHA1Thumbprint("x509SHA1Thumbprint") + .x509SHA256Thumbprint("x509SHA256Thumbprint") + .type("JWT") + .contentType("jwt-content-type") + .header("custom-header-name", "custom-header-value"); + // @formatter:on + } + + private static Map rsaJwk() { + Map rsaJwk = new HashMap<>(); + rsaJwk.put("kty", "RSA"); + rsaJwk.put("n", "modulus"); + rsaJwk.put("e", "exponent"); + return rsaJwk; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java new file mode 100644 index 00000000000..1b311979457 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; + +/* + * NOTE: + * This originated in gh-9208 (JwtEncoder), + * which is required to realize the feature in gh-8175 (JWT Client Authentication). + * However, we decided not to merge gh-9208 as part of the 5.5.0 release + * and instead packaged it up privately with the gh-8175 feature. + * We MAY merge gh-9208 in a later release but that is yet to be determined. + * + * gh-9208 Introduce JwtEncoder + * https://github.com/spring-projects/spring-security/pull/9208 + * + * gh-8175 Support JWT for Client Authentication + * https://github.com/spring-projects/spring-security/issues/8175 + */ + +/** + * @author Joe Grandja + */ +final class TestJwtClaimsSets { + + private TestJwtClaimsSets() { + } + + static JwtClaimsSet.Builder jwtClaimsSet() { + String issuer = "https://provider.com"; + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); + + // @formatter:off + return JwtClaimsSet.builder() + .issuer(issuer) + .subject("subject") + .audience(Collections.singletonList("client-1")) + .issuedAt(issuedAt) + .notBefore(issuedAt) + .expiresAt(expiresAt) + .id("jti") + .claim("custom-claim-name", "custom-claim-value"); + // @formatter:on + } + +} diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java index c4a93a93e6c..7193f90eb63 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,6 +58,17 @@ public final class ClientAuthenticationMethod implements Serializable { public static final ClientAuthenticationMethod CLIENT_SECRET_POST = new ClientAuthenticationMethod( "client_secret_post"); + /** + * @since 5.5 + */ + public static final ClientAuthenticationMethod CLIENT_SECRET_JWT = new ClientAuthenticationMethod( + "client_secret_jwt"); + + /** + * @since 5.5 + */ + public static final ClientAuthenticationMethod PRIVATE_KEY_JWT = new ClientAuthenticationMethod("private_key_jwt"); + /** * @since 5.2 */ diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java index e77cd847f44..464f6b1493c 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,6 +48,18 @@ public interface OAuth2ParameterNames { */ String CLIENT_SECRET = "client_secret"; + /** + * {@code client_assertion_type} - used in Access Token Request. + * @since 5.5 + */ + String CLIENT_ASSERTION_TYPE = "client_assertion_type"; + + /** + * {@code client_assertion} - used in Access Token Request. + * @since 5.5 + */ + String CLIENT_ASSERTION = "client_assertion"; + /** * {@code redirect_uri} - used in Authorization Request and Access Token Request. */ diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java index 7a1f2be4f8c..f008049170d 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,6 +53,16 @@ public void getValueWhenAuthenticationMethodClientSecretPostThenReturnClientSecr assertThat(ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue()).isEqualTo("client_secret_post"); } + @Test + public void getValueWhenAuthenticationMethodClientSecretJwtThenReturnClientSecretJwt() { + assertThat(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()).isEqualTo("client_secret_jwt"); + } + + @Test + public void getValueWhenAuthenticationMethodPrivateKeyJwtThenReturnPrivateKeyJwt() { + assertThat(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue()).isEqualTo("private_key_jwt"); + } + @Test public void getValueWhenAuthenticationMethodNoneThenReturnNone() { assertThat(ClientAuthenticationMethod.NONE.getValue()).isEqualTo("none"); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java new file mode 100644 index 00000000000..412adbfd4d2 --- /dev/null +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jose; + +import java.security.interfaces.ECPrivateKey; +import java.security.interfaces.ECPublicKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; + +import javax.crypto.SecretKey; + +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; + +/** + * @author Joe Grandja + */ +public final class TestJwks { + + // @formatter:off + public static final RSAKey DEFAULT_RSA_JWK = + jwk( + TestKeys.DEFAULT_PUBLIC_KEY, + TestKeys.DEFAULT_PRIVATE_KEY + ).build(); + // @formatter:on + + // @formatter:off + public static final ECKey DEFAULT_EC_JWK = + jwk( + (ECPublicKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPublic(), + (ECPrivateKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPrivate() + ).build(); + // @formatter:on + + // @formatter:off + public static final OctetSequenceKey DEFAULT_SECRET_JWK = + jwk( + TestKeys.DEFAULT_SECRET_KEY + ).build(); + // @formatter:on + + private TestJwks() { + } + + public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) { + // @formatter:off + return new RSAKey.Builder(publicKey) + .privateKey(privateKey) + .keyID("rsa-jwk-kid"); + // @formatter:on + } + + public static ECKey.Builder jwk(ECPublicKey publicKey, ECPrivateKey privateKey) { + // @formatter:off + Curve curve = Curve.forECParameterSpec(publicKey.getParams()); + return new ECKey.Builder(curve, publicKey) + .privateKey(privateKey) + .keyID("ec-jwk-kid"); + // @formatter:on + } + + public static OctetSequenceKey.Builder jwk(SecretKey secretKey) { + // @formatter:off + return new OctetSequenceKey.Builder(secretKey) + .keyID("secret-jwk-kid"); + // @formatter:on + } + +} diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java index 7a2b7fb70db..3b11b4504e1 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,17 @@ package org.springframework.security.oauth2.jose; +import java.math.BigInteger; import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; +import java.security.spec.ECFieldFp; +import java.security.spec.ECParameterSpec; +import java.security.spec.ECPoint; +import java.security.spec.EllipticCurve; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.X509EncodedKeySpec; @@ -109,6 +116,34 @@ public final class TestKeys { } } + public static final KeyPair DEFAULT_RSA_KEY_PAIR = new KeyPair(DEFAULT_PUBLIC_KEY, DEFAULT_PRIVATE_KEY); + + public static final KeyPair DEFAULT_EC_KEY_PAIR = generateEcKeyPair(); + + static KeyPair generateEcKeyPair() { + EllipticCurve ellipticCurve = new EllipticCurve( + new ECFieldFp(new BigInteger( + "115792089210356248762697446949407573530086143415290314195533631308867097853951")), + new BigInteger("115792089210356248762697446949407573530086143415290314195533631308867097853948"), + new BigInteger("41058363725152142129326129780047268409114441015993725554835256314039467401291")); + ECPoint ecPoint = new ECPoint( + new BigInteger("48439561293906451759052585252797914202762949526041747995844080717082404635286"), + new BigInteger("36134250956749795798585127919587881956611106672985015071877198253568414405109")); + ECParameterSpec ecParameterSpec = new ECParameterSpec(ellipticCurve, ecPoint, + new BigInteger("115792089210356248762697446949407573529996955224135760342422259061068512044369"), 1); + + KeyPair keyPair; + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC"); + keyPairGenerator.initialize(ecParameterSpec); + keyPair = keyPairGenerator.generateKeyPair(); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + return keyPair; + } + private TestKeys() { } From 2ee0a389f9dce07a0b394a00cd66ab032eda75cb Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 6 Apr 2021 05:21:00 -0400 Subject: [PATCH 2/4] Add AbstractOAuth2AuthorizationGrantRequestEntityConverter --- ...horizationGrantRequestEntityConverter.java | 156 ++++++++++++++++++ ...entAuthenticationParametersConverter.java} | 31 ++-- ...zationCodeGrantRequestEntityConverter.java | 58 +------ ...horizationGrantRequestEntityConverter.java | 61 ------- ...redentialsGrantRequestEntityConverter.java | 58 +------ ...h2PasswordGrantRequestEntityConverter.java | 54 +----- ...freshTokenGrantRequestEntityConverter.java | 56 +------ ...orizationCodeTokenResponseClientTests.java | 14 +- ...ntCredentialsTokenResponseClientTests.java | 14 +- ...faultPasswordTokenResponseClientTests.java | 14 +- ...tRefreshTokenTokenResponseClientTests.java | 14 +- ...thenticationParametersConverterTests.java} | 63 +++---- ...nCodeGrantRequestEntityConverterTests.java | 60 +++++-- ...tialsGrantRequestEntityConverterTests.java | 59 +++++-- ...swordGrantRequestEntityConverterTests.java | 59 +++++-- ...TokenGrantRequestEntityConverterTests.java | 61 +++++-- 16 files changed, 443 insertions(+), 389 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java rename oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/{NimbusJwtClientAuthenticationCustomizer.java => NimbusJwtClientAuthenticationParametersConverter.java} (87%) delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java rename oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/{NimbusJwtClientAuthenticationCustomizerTests.java => NimbusJwtClientAuthenticationParametersConverterTests.java} (72%) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java new file mode 100644 index 00000000000..0891e413b30 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.net.URI; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.RequestEntity; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Base implementation of a {@link Converter} that converts the provided + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link RequestEntity} + * representation of an OAuth 2.0 Access Token Request for the Authorization Grant. + * + * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} + * @author Joe Grandja + * @since 5.5 + * @see Converter + * @see AbstractOAuth2AuthorizationGrantRequest + * @see RequestEntity + */ +public abstract class AbstractOAuth2AuthorizationGrantRequestEntityConverter + implements Converter> { + + // @formatter:off + protected Converter headersConverter = + (authorizationGrantRequest) -> OAuth2AuthorizationGrantRequestEntityUtils + .getTokenRequestHeaders(authorizationGrantRequest.getClientRegistration()); + // @formatter:on + + protected Converter> parametersConverter = this::createParameters; + + /** + * Sub-class constructor. + */ + protected AbstractOAuth2AuthorizationGrantRequestEntityConverter() { + } + + @Override + public RequestEntity convert(T authorizationGrantRequest) { + HttpHeaders headers = this.headersConverter.convert(authorizationGrantRequest); + MultiValueMap parameters = this.parametersConverter.convert(authorizationGrantRequest); + URI uri = UriComponentsBuilder + .fromUriString(authorizationGrantRequest.getClientRegistration().getProviderDetails().getTokenUri()) + .build().toUri(); + return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); + } + + /** + * Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body. + * @param authorizationGrantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + protected abstract MultiValueMap createParameters(T authorizationGrantRequest); + + /** + * Sets the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} + * used in the OAuth 2.0 Access Token Request headers. + * @param headersConverter the {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to {@link HttpHeaders} + */ + public final void setHeadersConverter(Converter headersConverter) { + Assert.notNull(headersConverter, "headersConverter cannot be null"); + this.headersConverter = headersConverter; + } + + /** + * Add (compose) the provided {@code headersConverter} to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} + * used in the OAuth 2.0 Access Token Request headers. + * @param headersConverter the {@link Converter} to add (compose) to the current + * {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link HttpHeaders} + */ + public final void addHeadersConverter(Converter headersConverter) { + Assert.notNull(headersConverter, "headersConverter cannot be null"); + Converter currentHeadersConverter = this.headersConverter; + this.headersConverter = (authorizationGrantRequest) -> { + // Append headers using a Composite Converter + HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest); + if (headers == null) { + headers = new HttpHeaders(); + } + HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest); + if (headersToAdd != null) { + headers.addAll(headersToAdd); + } + return headers; + }; + } + + /** + * Sets the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * of the parameters used in the OAuth 2.0 Access Token Request body. + * @param parametersConverter the {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link MultiValueMap} of the + * parameters + */ + public final void setParametersConverter(Converter> parametersConverter) { + Assert.notNull(parametersConverter, "parametersConverter cannot be null"); + this.parametersConverter = parametersConverter; + } + + /** + * Add (compose) the provided {@code parametersConverter} to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * of the parameters used in the OAuth 2.0 Access Token Request body. + * @param parametersConverter the {@link Converter} to add (compose) to the current + * {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link MultiValueMap} of the + * parameters + */ + public final void addParametersConverter(Converter> parametersConverter) { + Assert.notNull(parametersConverter, "parametersConverter cannot be null"); + Converter> currentParametersConverter = this.parametersConverter; + this.parametersConverter = (authorizationGrantRequest) -> { + // Append parameters using a Composite Converter + MultiValueMap parameters = currentParametersConverter.convert(authorizationGrantRequest); + if (parameters == null) { + parameters = new LinkedMultiValueMap<>(); + } + MultiValueMap parametersToAdd = parametersConverter.convert(authorizationGrantRequest); + if (parametersToAdd != null) { + parameters.addAll(parametersToAdd); + } + return parameters; + }; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java similarity index 87% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index e256c0c2645..aabaa92cd71 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizer.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -31,7 +31,7 @@ import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; -import org.springframework.http.HttpHeaders; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -42,15 +42,15 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; /** - * An implementation of an - * {@link OAuth2AuthorizationGrantRequestEntityConverter.Customizer} that customizes the - * OAuth 2.0 Access Token Request by adding a signed JSON Web Token (JWS) to be used for - * client authentication at the Authorization Server's Token Endpoint. The private/secret - * key used for signing the JWS is supplied by the {@code com.nimbusds.jose.jwk.JWK} - * resolver provided via the constructor. + * A {@link Converter} that customizes the OAuth 2.0 Access Token Request parameters by + * adding a signed JSON Web Token (JWS) to be used for client authentication at the + * Authorization Server's Token Endpoint. The private/secret key used for signing the JWS + * is supplied by the {@code com.nimbusds.jose.jwk.JWK} resolver provided via the + * constructor. * *

* NOTE: This implementation uses the Nimbus JOSE + JWT SDK. @@ -58,7 +58,7 @@ * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} * @author Joe Grandja * @since 5.5 - * @see OAuth2AuthorizationGrantRequestEntityConverter.Customizer + * @see Converter * @see com.nimbusds.jose.jwk.JWK * @see JwtCustomizer * @see 2.2 @@ -68,8 +68,8 @@ * @see Nimbus * JOSE + JWT SDK */ -public final class NimbusJwtClientAuthenticationCustomizer - implements OAuth2AuthorizationGrantRequestEntityConverter.Customizer { +public final class NimbusJwtClientAuthenticationParametersConverter + implements Converter> { private static final String INVALID_KEY_ERROR_CODE = "invalid_key"; @@ -90,22 +90,20 @@ public final class NimbusJwtClientAuthenticationCustomizer jwkResolver) { + public NimbusJwtClientAuthenticationParametersConverter(Function jwkResolver) { Assert.notNull(jwkResolver, "jwkResolver cannot be null"); this.jwkResolver = jwkResolver; } @Override - public void customize(T authorizationGrantRequest, HttpHeaders headers, MultiValueMap parameters) { + public MultiValueMap convert(T authorizationGrantRequest) { Assert.notNull(authorizationGrantRequest, "authorizationGrantRequest cannot be null"); - Assert.notNull(headers, "headers cannot be null"); - Assert.notNull(parameters, "parameters cannot be null"); ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); if (!ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(clientRegistration.getClientAuthenticationMethod()) && !ClientAuthenticationMethod.CLIENT_SECRET_JWT .equals(clientRegistration.getClientAuthenticationMethod())) { - return; + return null; } JWK jwk = this.jwkResolver.apply(clientRegistration); @@ -154,8 +152,11 @@ public void customize(T authorizationGrantRequest, HttpHeaders headers, MultiVal Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE_VALUE); parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION, jws.getTokenValue()); + + return parameters; } /** diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java index 65ec6b28fc3..364a4ac591e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java @@ -16,76 +16,32 @@ package org.springframework.security.oauth2.client.endpoint; -import java.net.URI; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; -import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.web.util.UriComponentsBuilder; /** - * A {@link Converter} that converts the provided - * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link RequestEntity} representation - * of an OAuth 2.0 Access Token Request for the Authorization Code Grant. + * An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter} + * that converts the provided {@link OAuth2AuthorizationCodeGrantRequest} to a + * {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the + * Authorization Code Grant. * * @author Joe Grandja * @since 5.1 - * @see OAuth2AuthorizationGrantRequestEntityConverter + * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2AuthorizationCodeGrantRequest * @see RequestEntity */ public class OAuth2AuthorizationCodeGrantRequestEntityConverter - implements OAuth2AuthorizationGrantRequestEntityConverter { - - private Customizer customizer = (request, headers, parameters) -> { - }; + extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { - /** - * Returns the {@link RequestEntity} used for the Access Token Request. - * @param authorizationCodeGrantRequest the authorization code grant request - * @return the {@link RequestEntity} used for the Access Token Request - */ @Override - public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { - ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap parameters = createParameters(authorizationCodeGrantRequest); - this.customizer.customize(authorizationCodeGrantRequest, headers, parameters); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() - .toUri(); - return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); - } - - /** - * Sets the {@link Customizer} to be provided the opportunity to customize the - * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth - * 2.0 Access Token Request. - * @param customizer the {@link Customizer} to be provided the opportunity to - * customize the OAuth 2.0 Access Token Request - * @since 5.5 - */ - public final void setCustomizer(Customizer customizer) { - Assert.notNull(customizer, "customizer cannot be null"); - this.customizer = customizer; - } - - /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body. - * @param authorizationCodeGrantRequest the authorization code grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body - */ - private MultiValueMap createParameters( + protected MultiValueMap createParameters( OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); OAuth2AuthorizationExchange authorizationExchange = authorizationCodeGrantRequest.getAuthorizationExchange(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java deleted file mode 100644 index 2f70c79e04f..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityConverter.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.endpoint; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpHeaders; -import org.springframework.http.RequestEntity; -import org.springframework.util.MultiValueMap; - -/** - * Implementations of this interface are responsible for {@link Converter#convert(Object) - * converting} the provided {@link AbstractOAuth2AuthorizationGrantRequest authorization - * grant credential} to a {@link RequestEntity} representation of an OAuth 2.0 Access - * Token Request. - * - * @author Joe Grandja - * @since 5.5 - * @see Converter - * @see AbstractOAuth2AuthorizationGrantRequest - * @see RequestEntity - * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} - */ -@FunctionalInterface -public interface OAuth2AuthorizationGrantRequestEntityConverter - extends Converter> { - - /** - * Implementations of this interface are provided the opportunity to customize the - * {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. - * - * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} - */ - @FunctionalInterface - interface Customizer { - - /** - * Customize the {@link HttpHeaders headers} and/or {@link MultiValueMap - * parameters} of the OAuth 2.0 Access Token Request. - * @param authorizationGrantRequest the authorization grant request - * @param headers the headers - * @param parameters the parameters - */ - void customize(T authorizationGrantRequest, HttpHeaders headers, MultiValueMap parameters); - - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java index ade884e7b94..14b51138dfd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java @@ -16,76 +16,32 @@ package org.springframework.security.oauth2.client.endpoint; -import java.net.URI; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponentsBuilder; /** - * A {@link Converter} that converts the provided - * {@link OAuth2ClientCredentialsGrantRequest} to a {@link RequestEntity} representation - * of an OAuth 2.0 Access Token Request for the Client Credentials Grant. + * An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter} + * that converts the provided {@link OAuth2ClientCredentialsGrantRequest} to a + * {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the + * Client Credentials Grant. * * @author Joe Grandja * @since 5.1 - * @see OAuth2AuthorizationGrantRequestEntityConverter + * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2ClientCredentialsGrantRequest * @see RequestEntity */ public class OAuth2ClientCredentialsGrantRequestEntityConverter - implements OAuth2AuthorizationGrantRequestEntityConverter { - - private Customizer customizer = (request, headers, parameters) -> { - }; + extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { - /** - * Returns the {@link RequestEntity} used for the Access Token Request. - * @param clientCredentialsGrantRequest the client credentials grant request - * @return the {@link RequestEntity} used for the Access Token Request - */ @Override - public RequestEntity convert(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { - ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap parameters = createParameters(clientCredentialsGrantRequest); - this.customizer.customize(clientCredentialsGrantRequest, headers, parameters); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() - .toUri(); - return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); - } - - /** - * Sets the {@link Customizer} to be provided the opportunity to customize the - * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth - * 2.0 Access Token Request. - * @param customizer the {@link Customizer} to be provided the opportunity to - * customize the OAuth 2.0 Access Token Request - * @since 5.5 - */ - public final void setCustomizer(Customizer customizer) { - Assert.notNull(customizer, "customizer cannot be null"); - this.customizer = customizer; - } - - /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body. - * @param clientCredentialsGrantRequest the client credentials grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body - */ - private MultiValueMap createParameters( + protected MultiValueMap createParameters( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); MultiValueMap parameters = new LinkedMultiValueMap<>(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java index aee6d57f6a9..34dc96479be 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java @@ -16,76 +16,32 @@ package org.springframework.security.oauth2.client.endpoint; -import java.net.URI; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponentsBuilder; /** - * A {@link Converter} that converts the provided {@link OAuth2PasswordGrantRequest} to a + * An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter} + * that converts the provided {@link OAuth2PasswordGrantRequest} to a * {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the * Resource Owner Password Credentials Grant. * * @author Joe Grandja * @since 5.2 - * @see OAuth2AuthorizationGrantRequestEntityConverter + * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2PasswordGrantRequest * @see RequestEntity */ public class OAuth2PasswordGrantRequestEntityConverter - implements OAuth2AuthorizationGrantRequestEntityConverter { - - private Customizer customizer = (request, headers, parameters) -> { - }; + extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { - /** - * Returns the {@link RequestEntity} used for the Access Token Request. - * @param passwordGrantRequest the password grant request - * @return the {@link RequestEntity} used for the Access Token Request - */ @Override - public RequestEntity convert(OAuth2PasswordGrantRequest passwordGrantRequest) { - ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap parameters = createParameters(passwordGrantRequest); - this.customizer.customize(passwordGrantRequest, headers, parameters); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() - .toUri(); - return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); - } - - /** - * Sets the {@link Customizer} to be provided the opportunity to customize the - * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth - * 2.0 Access Token Request. - * @param customizer the {@link Customizer} to be provided the opportunity to - * customize the OAuth 2.0 Access Token Request - * @since 5.5 - */ - public final void setCustomizer(Customizer customizer) { - Assert.notNull(customizer, "customizer cannot be null"); - this.customizer = customizer; - } - - /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body. - * @param passwordGrantRequest the password grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body - */ - private MultiValueMap createParameters(OAuth2PasswordGrantRequest passwordGrantRequest) { + protected MultiValueMap createParameters(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java index 3c8fdae08b0..e98652ae544 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -16,76 +16,32 @@ package org.springframework.security.oauth2.client.endpoint; -import java.net.URI; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponentsBuilder; /** - * A {@link Converter} that converts the provided {@link OAuth2RefreshTokenGrantRequest} - * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the + * An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter} + * that converts the provided {@link OAuth2RefreshTokenGrantRequest} to a + * {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the * Refresh Token Grant. * * @author Joe Grandja * @since 5.2 - * @see OAuth2AuthorizationGrantRequestEntityConverter + * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2RefreshTokenGrantRequest * @see RequestEntity */ public class OAuth2RefreshTokenGrantRequestEntityConverter - implements OAuth2AuthorizationGrantRequestEntityConverter { - - private Customizer customizer = (request, headers, parameters) -> { - }; + extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { - /** - * Returns the {@link RequestEntity} used for the Access Token Request. - * @param refreshTokenGrantRequest the refresh token grant request - * @return the {@link RequestEntity} used for the Access Token Request - */ @Override - public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { - ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); - MultiValueMap parameters = createParameters(refreshTokenGrantRequest); - this.customizer.customize(refreshTokenGrantRequest, headers, parameters); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() - .toUri(); - return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri); - } - - /** - * Sets the {@link Customizer} to be provided the opportunity to customize the - * {@link HttpHeaders headers} and/or {@link MultiValueMap parameters} of the OAuth - * 2.0 Access Token Request. - * @param customizer the {@link Customizer} to be provided the opportunity to - * customize the OAuth 2.0 Access Token Request - * @since 5.5 - */ - public final void setCustomizer(Customizer customizer) { - Assert.notNull(customizer, "customizer cannot be null"); - this.customizer = customizer; - } - - /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body. - * @param refreshTokenGrantRequest the refresh token grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token - * Request body - */ - private MultiValueMap createParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + protected MultiValueMap createParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java index 07cc2621170..cf6f86d2032 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java @@ -188,12 +188,12 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter SecretKeySpec secretKey = new SecretKeySpec( clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); JWK jwk = TestJwks.jwk(secretKey).build(); Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)); RecordedRequest recordedRequest = this.server.takeRequest(); @@ -221,10 +221,10 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter JWK jwk = TestJwks.DEFAULT_RSA_JWK; Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)); RecordedRequest recordedRequest = this.server.takeRequest(); @@ -235,11 +235,11 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre assertThat(formParameters).contains("client_assertion="); } - private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { - NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( jwkResolver); OAuth2AuthorizationCodeGrantRequestEntityConverter requestEntityConverter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); - requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter); this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java index 2ce24fe9c9b..798bb901ca0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java @@ -194,12 +194,12 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter SecretKeySpec secretKey = new SecretKeySpec( clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); JWK jwk = TestJwks.jwk(secretKey).build(); Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); @@ -229,10 +229,10 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter JWK jwk = TestJwks.DEFAULT_RSA_JWK; Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); @@ -245,11 +245,11 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre assertThat(formParameters).contains("client_assertion="); } - private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { - NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( jwkResolver); OAuth2ClientCredentialsGrantRequestEntityConverter requestEntityConverter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); - requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter); this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java index 8fb756a8112..43c02e7492d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java @@ -170,12 +170,12 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter SecretKeySpec secretKey = new SecretKeySpec( clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); JWK jwk = TestJwks.jwk(secretKey).build(); Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, this.username, this.password); @@ -205,10 +205,10 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter JWK jwk = TestJwks.DEFAULT_RSA_JWK; Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, this.username, this.password); @@ -221,11 +221,11 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre assertThat(formParameters).contains("client_assertion="); } - private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { - NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( jwkResolver); OAuth2PasswordGrantRequestEntityConverter requestEntityConverter = new OAuth2PasswordGrantRequestEntityConverter(); - requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter); this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java index d5f1413d453..e7c0aed72f0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java @@ -171,12 +171,12 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter SecretKeySpec secretKey = new SecretKeySpec( clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); JWK jwk = TestJwks.jwk(secretKey).build(); Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -206,10 +206,10 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre .build(); // @formatter:on - // Configure Jwt client authentication customizer + // Configure Jwt client authentication converter JWK jwk = TestJwks.DEFAULT_RSA_JWK; Function jwkResolver = (registration) -> jwk; - configureJwtClientAuthenticationCustomizer(jwkResolver); + configureJwtClientAuthenticationConverter(jwkResolver); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -222,11 +222,11 @@ public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAre assertThat(formParameters).contains("client_assertion="); } - private void configureJwtClientAuthenticationCustomizer(Function jwkResolver) { - NimbusJwtClientAuthenticationCustomizer jwtClientAuthenticationCustomizer = new NimbusJwtClientAuthenticationCustomizer<>( + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( jwkResolver); OAuth2RefreshTokenGrantRequestEntityConverter requestEntityConverter = new OAuth2RefreshTokenGrantRequestEntityConverter(); - requestEntityConverter.setCustomizer(jwtClientAuthenticationCustomizer); + requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter); this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java similarity index 72% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java index d09b12a5430..234b436151d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationCustomizerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java @@ -25,7 +25,6 @@ import org.junit.Before; import org.junit.Test; -import org.springframework.http.HttpHeaders; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -37,7 +36,6 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; -import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; @@ -49,61 +47,43 @@ import static org.mockito.Mockito.verifyNoInteractions; /** - * Tests for {@link NimbusJwtClientAuthenticationCustomizer}. + * Tests for {@link NimbusJwtClientAuthenticationParametersConverter}. * * @author Joe Grandja */ -public class NimbusJwtClientAuthenticationCustomizerTests { +public class NimbusJwtClientAuthenticationParametersConverterTests { private Function jwkResolver; - private NimbusJwtClientAuthenticationCustomizer customizer; + private NimbusJwtClientAuthenticationParametersConverter converter; @Before public void setup() { this.jwkResolver = mock(Function.class); - this.customizer = new NimbusJwtClientAuthenticationCustomizer<>(this.jwkResolver); + this.converter = new NimbusJwtClientAuthenticationParametersConverter<>(this.jwkResolver); } @Test public void constructorWhenJwkResolverNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwtClientAuthenticationCustomizer<>(null)) + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusJwtClientAuthenticationParametersConverter<>(null)) .withMessage("jwkResolver cannot be null"); } @Test public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.customizer.setJwtCustomizer(null)) + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setJwtCustomizer(null)) .withMessage("jwtCustomizer cannot be null"); } @Test - public void customizeWhenAuthorizationGrantRequestNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.customizer.customize(null, new HttpHeaders(), new LinkedMultiValueMap<>())) + public void convertWhenAuthorizationGrantRequestNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.convert(null)) .withMessage("authorizationGrantRequest cannot be null"); } @Test - public void customizeWhenHeadersNullThenThrowIllegalArgumentException() { - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - TestClientRegistrations.clientCredentials().build()); - assertThatIllegalArgumentException().isThrownBy( - () -> this.customizer.customize(clientCredentialsGrantRequest, null, new LinkedMultiValueMap<>())) - .withMessage("headers cannot be null"); - } - - @Test - public void customizeWhenParametersNullThenThrowIllegalArgumentException() { - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( - TestClientRegistrations.clientCredentials().build()); - assertThatIllegalArgumentException() - .isThrownBy(() -> this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), null)) - .withMessage("parameters cannot be null"); - } - - @Test - public void customizeWhenOtherClientAuthenticationMethodThenNotCustomized() { + public void convertWhenOtherClientAuthenticationMethodThenNotCustomized() { // @formatter:off ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) @@ -111,12 +91,12 @@ public void customizeWhenOtherClientAuthenticationMethodThenNotCustomized() { // @formatter:on OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); - this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), new LinkedMultiValueMap<>()); + assertThat(this.converter.convert(clientCredentialsGrantRequest)).isNull(); verifyNoInteractions(this.jwkResolver); } @Test - public void customizeWhenJwkNotResolvedThenThrowOAuth2AuthorizationException() { + public void convertWhenJwkNotResolvedThenThrowOAuth2AuthorizationException() { // @formatter:off ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) @@ -125,19 +105,18 @@ public void customizeWhenJwkNotResolvedThenThrowOAuth2AuthorizationException() { OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); assertThatExceptionOfType(OAuth2AuthorizationException.class) - .isThrownBy(() -> this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), - new LinkedMultiValueMap<>())) + .isThrownBy(() -> this.converter.convert(clientCredentialsGrantRequest)) .withMessage("[invalid_key] Failed to resolve JWK signing key for client registration '" + clientRegistration.getRegistrationId() + "'."); } @Test - public void customizeWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() throws Exception { + public void convertWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() throws Exception { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; given(this.jwkResolver.apply(any())).willReturn(rsaJwk); // Add custom claim - this.customizer.setJwtCustomizer( + this.converter.setJwtCustomizer( (authorizationGrantRequest, headers, claims) -> claims.put("custom-claim", "custom-value")); // @formatter:off @@ -148,9 +127,7 @@ public void customizeWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); - MultiValueMap parameters = new LinkedMultiValueMap<>(); - - this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), parameters); + MultiValueMap parameters = this.converter.convert(clientCredentialsGrantRequest); assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE)) .isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); @@ -173,12 +150,12 @@ public void customizeWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() } @Test - public void customizeWhenClientSecretJwtClientAuthenticationMethodThenCustomized() { + public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized() { OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK; given(this.jwkResolver.apply(any())).willReturn(secretJwk); // Add custom claim - this.customizer.setJwtCustomizer( + this.converter.setJwtCustomizer( (authorizationGrantRequest, headers, claims) -> claims.put("custom-claim", "custom-value")); // @formatter:off @@ -189,9 +166,7 @@ public void customizeWhenClientSecretJwtClientAuthenticationMethodThenCustomized OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); - MultiValueMap parameters = new LinkedMultiValueMap<>(); - - this.customizer.customize(clientCredentialsGrantRequest, new HttpHeaders(), parameters); + MultiValueMap parameters = this.converter.convert(clientCredentialsGrantRequest); assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE)) .isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java index 91a7fe3f75a..d2e59e6b1a0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java @@ -21,7 +21,9 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.InOrder; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -42,8 +44,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2AuthorizationCodeGrantRequestEntityConverter}. @@ -60,23 +62,61 @@ public void setup() { } @Test - public void setCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) - .withMessage("customizer cannot be null"); + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); } @Test - public void convertWhenCustomizerSetThenCalled() { - OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( - OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); - this.converter.setCustomizer(customizer); + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenHeadersConverterSetThenCalled() { + Converter headersConverter1 = mock(Converter.class); + this.converter.setHeadersConverter(headersConverter1); + Converter headersConverter2 = mock(Converter.class); + this.converter.addHeadersConverter(headersConverter2); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( + clientRegistration, authorizationExchange); + this.converter.convert(authorizationCodeGrantRequest); + InOrder inOrder = inOrder(headersConverter1, headersConverter2); + inOrder.verify(headersConverter1).convert(any(OAuth2AuthorizationCodeGrantRequest.class)); + inOrder.verify(headersConverter2).convert(any(OAuth2AuthorizationCodeGrantRequest.class)); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() { + Converter> parametersConverter1 = mock( + Converter.class); + this.converter.setParametersConverter(parametersConverter1); + Converter> parametersConverter2 = mock( + Converter.class); + this.converter.addParametersConverter(parametersConverter2); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success(); OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( clientRegistration, authorizationExchange); this.converter.convert(authorizationCodeGrantRequest); - verify(customizer).customize(any(OAuth2AuthorizationCodeGrantRequest.class), any(HttpHeaders.class), - any(MultiValueMap.class)); + InOrder inOrder = inOrder(parametersConverter1, parametersConverter2); + inOrder.verify(parametersConverter1).convert(any(OAuth2AuthorizationCodeGrantRequest.class)); + inOrder.verify(parametersConverter2).convert(any(OAuth2AuthorizationCodeGrantRequest.class)); } @SuppressWarnings("unchecked") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java index 03c3476aea9..671743ddbdb 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java @@ -18,7 +18,9 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.InOrder; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -32,8 +34,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2ClientCredentialsGrantRequestEntityConverter}. @@ -50,22 +52,59 @@ public void setup() { } @Test - public void setCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) - .withMessage("customizer cannot be null"); + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); } @Test - public void convertWhenCustomizerSetThenCalled() { - OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( - OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); - this.converter.setCustomizer(customizer); + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenHeadersConverterSetThenCalled() { + Converter headersConverter1 = mock(Converter.class); + this.converter.setHeadersConverter(headersConverter1); + Converter headersConverter2 = mock(Converter.class); + this.converter.addHeadersConverter(headersConverter2); + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + this.converter.convert(clientCredentialsGrantRequest); + InOrder inOrder = inOrder(headersConverter1, headersConverter2); + inOrder.verify(headersConverter1).convert(any(OAuth2ClientCredentialsGrantRequest.class)); + inOrder.verify(headersConverter2).convert(any(OAuth2ClientCredentialsGrantRequest.class)); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() { + Converter> parametersConverter1 = mock( + Converter.class); + this.converter.setParametersConverter(parametersConverter1); + Converter> parametersConverter2 = mock( + Converter.class); + this.converter.addParametersConverter(parametersConverter2); ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); this.converter.convert(clientCredentialsGrantRequest); - verify(customizer).customize(any(OAuth2ClientCredentialsGrantRequest.class), any(HttpHeaders.class), - any(MultiValueMap.class)); + InOrder inOrder = inOrder(parametersConverter1, parametersConverter2); + inOrder.verify(parametersConverter1).convert(any(OAuth2ClientCredentialsGrantRequest.class)); + inOrder.verify(parametersConverter2).convert(any(OAuth2ClientCredentialsGrantRequest.class)); } @SuppressWarnings("unchecked") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java index 84b5c5aa661..e91359c8b1d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java @@ -18,7 +18,9 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.InOrder; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -32,8 +34,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2PasswordGrantRequestEntityConverter}. @@ -50,22 +52,59 @@ public void setup() { } @Test - public void setCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) - .withMessage("customizer cannot be null"); + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); } @Test - public void convertWhenCustomizerSetThenCalled() { - OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( - OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); - this.converter.setCustomizer(customizer); + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenHeadersConverterSetThenCalled() { + Converter headersConverter1 = mock(Converter.class); + this.converter.setHeadersConverter(headersConverter1); + Converter headersConverter2 = mock(Converter.class); + this.converter.addHeadersConverter(headersConverter2); + ClientRegistration clientRegistration = TestClientRegistrations.password().build(); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", + "password"); + this.converter.convert(passwordGrantRequest); + InOrder inOrder = inOrder(headersConverter1, headersConverter2); + inOrder.verify(headersConverter1).convert(any(OAuth2PasswordGrantRequest.class)); + inOrder.verify(headersConverter2).convert(any(OAuth2PasswordGrantRequest.class)); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() { + Converter> parametersConverter1 = mock( + Converter.class); + this.converter.setParametersConverter(parametersConverter1); + Converter> parametersConverter2 = mock( + Converter.class); + this.converter.addParametersConverter(parametersConverter2); ClientRegistration clientRegistration = TestClientRegistrations.password().build(); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", "password"); this.converter.convert(passwordGrantRequest); - verify(customizer).customize(any(OAuth2PasswordGrantRequest.class), any(HttpHeaders.class), - any(MultiValueMap.class)); + InOrder inOrder = inOrder(parametersConverter1, parametersConverter2); + inOrder.verify(parametersConverter1).convert(any(OAuth2PasswordGrantRequest.class)); + inOrder.verify(parametersConverter2).convert(any(OAuth2PasswordGrantRequest.class)); } @SuppressWarnings("unchecked") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java index 104fe3d9696..44beeb4bbb5 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java @@ -20,7 +20,9 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.InOrder; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -38,8 +40,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2RefreshTokenGrantRequestEntityConverter}. @@ -56,24 +58,63 @@ public void setup() { } @Test - public void setCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setCustomizer(null)) - .withMessage("customizer cannot be null"); + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); } @Test - public void convertWhenCustomizerSetThenCalled() { - OAuth2AuthorizationGrantRequestEntityConverter.Customizer customizer = mock( - OAuth2AuthorizationGrantRequestEntityConverter.Customizer.class); - this.converter.setCustomizer(customizer); + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenHeadersConverterSetThenCalled() { + Converter headersConverter1 = mock(Converter.class); + this.converter.setHeadersConverter(headersConverter1); + Converter headersConverter2 = mock(Converter.class); + this.converter.addHeadersConverter(headersConverter2); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + accessToken, refreshToken); + this.converter.convert(refreshTokenGrantRequest); + InOrder inOrder = inOrder(headersConverter1, headersConverter2); + inOrder.verify(headersConverter1).convert(any(OAuth2RefreshTokenGrantRequest.class)); + inOrder.verify(headersConverter2).convert(any(OAuth2RefreshTokenGrantRequest.class)); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() { + Converter> parametersConverter1 = mock( + Converter.class); + this.converter.setParametersConverter(parametersConverter1); + Converter> parametersConverter2 = mock( + Converter.class); + this.converter.addParametersConverter(parametersConverter2); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, accessToken, refreshToken); this.converter.convert(refreshTokenGrantRequest); - verify(customizer).customize(any(OAuth2RefreshTokenGrantRequest.class), any(HttpHeaders.class), - any(MultiValueMap.class)); + InOrder inOrder = inOrder(parametersConverter1, parametersConverter2); + inOrder.verify(parametersConverter1).convert(any(OAuth2RefreshTokenGrantRequest.class)); + inOrder.verify(parametersConverter2).convert(any(OAuth2RefreshTokenGrantRequest.class)); } @SuppressWarnings("unchecked") From acb1ec0044445fe134f5c4cb31dc8914455d298d Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 7 Apr 2021 11:47:11 -0400 Subject: [PATCH 3/4] Updates from review 1 --- etc/checkstyle/checkstyle-suppressions.xml | 1 + ...horizationGrantRequestEntityConverter.java | 41 ++- .../oauth2/client/endpoint/JoseHeader.java | 27 +- .../client/endpoint/JwsHeaderConverter.java | 144 --------- .../endpoint/JwtClaimsSetConverter.java | 109 ------- .../oauth2/client/endpoint/JwtEncoder.java | 77 ----- .../client/endpoint/NimbusJwsEncoder.java | 288 +++++++++++++----- ...ientAuthenticationParametersConverter.java | 8 +- 8 files changed, 264 insertions(+), 431 deletions(-) delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java diff --git a/etc/checkstyle/checkstyle-suppressions.xml b/etc/checkstyle/checkstyle-suppressions.xml index 8ad15ce4afc..6f58307877f 100644 --- a/etc/checkstyle/checkstyle-suppressions.xml +++ b/etc/checkstyle/checkstyle-suppressions.xml @@ -49,4 +49,5 @@ + diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java index 0891e413b30..7da5d530445 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequestEntityConverter.java @@ -39,27 +39,21 @@ * @see AbstractOAuth2AuthorizationGrantRequest * @see RequestEntity */ -public abstract class AbstractOAuth2AuthorizationGrantRequestEntityConverter +abstract class AbstractOAuth2AuthorizationGrantRequestEntityConverter implements Converter> { // @formatter:off - protected Converter headersConverter = + private Converter headersConverter = (authorizationGrantRequest) -> OAuth2AuthorizationGrantRequestEntityUtils .getTokenRequestHeaders(authorizationGrantRequest.getClientRegistration()); // @formatter:on - protected Converter> parametersConverter = this::createParameters; - - /** - * Sub-class constructor. - */ - protected AbstractOAuth2AuthorizationGrantRequestEntityConverter() { - } + private Converter> parametersConverter = this::createParameters; @Override public RequestEntity convert(T authorizationGrantRequest) { - HttpHeaders headers = this.headersConverter.convert(authorizationGrantRequest); - MultiValueMap parameters = this.parametersConverter.convert(authorizationGrantRequest); + HttpHeaders headers = getHeadersConverter().convert(authorizationGrantRequest); + MultiValueMap parameters = getParametersConverter().convert(authorizationGrantRequest); URI uri = UriComponentsBuilder .fromUriString(authorizationGrantRequest.getClientRegistration().getProviderDetails().getTokenUri()) .build().toUri(); @@ -73,7 +67,18 @@ public RequestEntity convert(T authorizationGrantRequest) { * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access * Token Request body */ - protected abstract MultiValueMap createParameters(T authorizationGrantRequest); + abstract MultiValueMap createParameters(T authorizationGrantRequest); + + /** + * Returns the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} + * used in the OAuth 2.0 Access Token Request headers. + * @return the {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to {@link HttpHeaders} + */ + final Converter getHeadersConverter() { + return this.headersConverter; + } /** * Sets the {@link Converter} used for converting the @@ -113,6 +118,18 @@ public final void addHeadersConverter(Converter headersConverter }; } + /** + * Returns the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * of the parameters used in the OAuth 2.0 Access Token Request body. + * @return the {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link MultiValueMap} of the + * parameters + */ + final Converter> getParametersConverter() { + return this.parametersConverter; + } + /** * Sets the {@link Converter} used for converting the * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java index e2d01b52ddf..8d06f3e3f18 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java @@ -119,7 +119,9 @@ URL getX509Uri() { /** * Returns the X.509 certificate chain that contains the X.509 public key certificate * or certificate chain corresponding to the key used to digitally sign the JWS or - * encrypt the JWE. + * encrypt the JWE. The certificate or certificate chain is represented as a + * {@code List} of certificate value {@code String}s. Each {@code String} in the + * {@code List} is a Base64-encoded DER PKIX certificate value. * @return the X.509 certificate chain */ List getX509CertificateChain() { @@ -245,7 +247,7 @@ Builder algorithm(JwaAlgorithm jwaAlgorithm) { * @return the {@link Builder} */ Builder jwkSetUri(String jwkSetUri) { - return header(JoseHeaderNames.JKU, jwkSetUri); + return header(JoseHeaderNames.JKU, convertAsURL(JoseHeaderNames.JKU, jwkSetUri)); } /** @@ -276,13 +278,15 @@ Builder keyId(String keyId) { * @return the {@link Builder} */ Builder x509Uri(String x509Uri) { - return header(JoseHeaderNames.X5U, x509Uri); + return header(JoseHeaderNames.X5U, convertAsURL(JoseHeaderNames.X5U, x509Uri)); } /** * Sets the X.509 certificate chain that contains the X.509 public key certificate * or certificate chain corresponding to the key used to digitally sign the JWS or - * encrypt the JWE. + * encrypt the JWE. The certificate or certificate chain is represented as a + * {@code List} of certificate value {@code String}s. Each {@code String} in the + * {@code List} is a Base64-encoded DER PKIX certificate value. * @param x509CertificateChain the X.509 certificate chain * @return the {@link Builder} */ @@ -371,19 +375,14 @@ Builder headers(Consumer> headersConsumer) { */ JoseHeader build() { Assert.notEmpty(this.headers, "headers cannot be empty"); - convertAsURL(JoseHeaderNames.JKU); - convertAsURL(JoseHeaderNames.X5U); return new JoseHeader(this.headers); } - private void convertAsURL(String header) { - Object value = this.headers.get(header); - if (value != null) { - URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class); - Assert.isTrue(convertedValue != null, - () -> "Unable to convert header '" + header + "' of type '" + value.getClass() + "' to URL."); - this.headers.put(header, convertedValue); - } + private static URL convertAsURL(String header, String value) { + URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class); + Assert.isTrue(convertedValue != null, + () -> "Unable to convert header '" + header + "' of type '" + value.getClass() + "' to URL."); + return convertedValue; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java deleted file mode 100644 index da3ccb32bc6..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwsHeaderConverter.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.endpoint; - -import java.net.URL; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -import com.nimbusds.jose.JOSEObjectType; -import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.JWSHeader; -import com.nimbusds.jose.jwk.JWK; -import com.nimbusds.jose.util.Base64; -import com.nimbusds.jose.util.Base64URL; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - -/** - * A {@link Converter} that converts a {@link JoseHeader} to - * {@code com.nimbusds.jose.JWSHeader}. - * - * @author Joe Grandja - * @since 5.5 - * @see Converter - * @see JoseHeader - * @see com.nimbusds.jose.JWSHeader - */ -final class JwsHeaderConverter implements Converter { - - @Override - public JWSHeader convert(JoseHeader headers) { - JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getAlgorithm().getName())); - - URL jwkSetUri = headers.getJwkSetUri(); - if (jwkSetUri != null) { - try { - builder.jwkURL(jwkSetUri.toURI()); - } - catch (Exception ex) { - throw new IllegalArgumentException( - "Unable to convert '" + JoseHeaderNames.JKU + "' JOSE header to a URI", ex); - } - } - - Map jwk = headers.getJwk(); - if (!CollectionUtils.isEmpty(jwk)) { - try { - builder.jwk(JWK.parse(jwk)); - } - catch (Exception ex) { - throw new IllegalArgumentException("Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header", ex); - } - } - - String keyId = headers.getKeyId(); - if (StringUtils.hasText(keyId)) { - builder.keyID(keyId); - } - - URL x509Uri = headers.getX509Uri(); - if (x509Uri != null) { - try { - builder.x509CertURL(x509Uri.toURI()); - } - catch (Exception ex) { - throw new IllegalArgumentException( - "Unable to convert '" + JoseHeaderNames.X5U + "' JOSE header to a URI", ex); - } - } - - List x509CertificateChain = headers.getX509CertificateChain(); - if (!CollectionUtils.isEmpty(x509CertificateChain)) { - builder.x509CertChain(x509CertificateChain.stream().map(Base64::new).collect(Collectors.toList())); - } - - String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint(); - if (StringUtils.hasText(x509SHA1Thumbprint)) { - builder.x509CertThumbprint(new Base64URL(x509SHA1Thumbprint)); - } - - String x509SHA256Thumbprint = headers.getX509SHA256Thumbprint(); - if (StringUtils.hasText(x509SHA256Thumbprint)) { - builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint)); - } - - String type = headers.getType(); - if (StringUtils.hasText(type)) { - builder.type(new JOSEObjectType(type)); - } - - String contentType = headers.getContentType(); - if (StringUtils.hasText(contentType)) { - builder.contentType(contentType); - } - - Set critical = headers.getCritical(); - if (!CollectionUtils.isEmpty(critical)) { - builder.criticalParams(critical); - } - - Map customHeaders = headers.getHeaders().entrySet().stream() - .filter((header) -> !JWSHeader.getRegisteredParameterNames().contains(header.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - if (!CollectionUtils.isEmpty(customHeaders)) { - builder.customParams(customHeaders); - } - - return builder.build(); - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java deleted file mode 100644 index 58cdd51fd28..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetConverter.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.endpoint; - -import java.time.Instant; -import java.util.Date; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import com.nimbusds.jwt.JWTClaimsSet; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.security.oauth2.jwt.JwtClaimNames; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - -/** - * A {@link Converter} that converts a {@link JwtClaimsSet} to - * {@code com.nimbusds.jwt.JWTClaimsSet}. - * - * @author Joe Grandja - * @since 5.5 - * @see Converter - * @see JwtClaimsSet - * @see com.nimbusds.jwt.JWTClaimsSet - */ -final class JwtClaimsSetConverter implements Converter { - - @Override - public JWTClaimsSet convert(JwtClaimsSet claims) { - JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder(); - - // NOTE: The value of the 'iss' claim is a String or URL (StringOrURI). - Object issuer = claims.getClaim(JwtClaimNames.ISS); - if (issuer != null) { - builder.issuer(issuer.toString()); - } - - String subject = claims.getSubject(); - if (StringUtils.hasText(subject)) { - builder.subject(subject); - } - - List audience = claims.getAudience(); - if (!CollectionUtils.isEmpty(audience)) { - builder.audience(audience); - } - - Instant expiresAt = claims.getExpiresAt(); - if (expiresAt != null) { - builder.expirationTime(Date.from(expiresAt)); - } - - Instant notBefore = claims.getNotBefore(); - if (notBefore != null) { - builder.notBeforeTime(Date.from(notBefore)); - } - - Instant issuedAt = claims.getIssuedAt(); - if (issuedAt != null) { - builder.issueTime(Date.from(issuedAt)); - } - - String jwtId = claims.getId(); - if (StringUtils.hasText(jwtId)) { - builder.jwtID(jwtId); - } - - Map customClaims = claims.getClaims().entrySet().stream() - .filter((claim) -> !JWTClaimsSet.getRegisteredNames().contains(claim.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - if (!CollectionUtils.isEmpty(customClaims)) { - customClaims.forEach(builder::claim); - } - - return builder.build(); - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java deleted file mode 100644 index fc20abf626e..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncoder.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.endpoint; - -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtDecoder; - -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - -/** - * Implementations of this interface are responsible for encoding a JSON Web Token (JWT) - * to it's compact claims representation format. - * - *

- * JWTs may be represented using the JWS Compact Serialization format for a JSON Web - * Signature (JWS) structure or JWE Compact Serialization format for a JSON Web Encryption - * (JWE) structure. Therefore, implementors are responsible for signing a JWS and/or - * encrypting a JWE. - * - * @author Anoop Garlapati - * @author Joe Grandja - * @since 5.5 - * @see Jwt - * @see JoseHeader - * @see JwtClaimsSet - * @see JwtDecoder - * @see JSON Web Token - * (JWT) - * @see JSON Web Signature - * (JWS) - * @see JSON Web Encryption - * (JWE) - * @see JWS - * Compact Serialization - * @see JWE - * Compact Serialization - */ -@FunctionalInterface -interface JwtEncoder { - - /** - * Encode the JWT to it's compact claims representation format. - * @param headers the JOSE header - * @param claims the JWT Claims Set - * @return a {@link Jwt} - * @throws JwtEncodingException if an error occurs while attempting to encode the JWT - */ - Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException; - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java index 08936959e2b..fbf82cac453 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java @@ -16,28 +16,40 @@ package org.springframework.security.oauth2.client.endpoint; +import java.net.URI; +import java.net.URL; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; -import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.KeyType; +import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.produce.JWSSignerFactory; +import com.nimbusds.jose.util.Base64; +import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; -import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; /* @@ -56,17 +68,16 @@ */ /** - * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the - * JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for - * signing the JWS is supplied by the {@code com.nimbusds.jose.jwk.source.JWKSource} - * provided via the constructor. + * A JWT encoder that encodes a JSON Web Token (JWT) using the JSON Web Signature (JWS) + * Compact Serialization format. The private/secret key used for signing the JWS is + * supplied by the {@code com.nimbusds.jose.jwk.source.JWKSource} provided via the + * constructor. * *

* NOTE: This implementation uses the Nimbus JOSE + JWT SDK. * * @author Joe Grandja * @since 5.5 - * @see JwtEncoder * @see com.nimbusds.jose.jwk.source.JWKSource * @see com.nimbusds.jose.jwk.JWK * @see JSON Web Token @@ -78,14 +89,10 @@ * @see Nimbus * JOSE + JWT SDK */ -final class NimbusJwsEncoder implements JwtEncoder { +final class NimbusJwsEncoder { private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; - private static final Converter JWS_HEADER_CONVERTER = new JwsHeaderConverter(); - - private static final Converter JWT_CLAIMS_SET_CONVERTER = new JwtClaimsSetConverter(); - private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory(); private final Map jwsSigners = new ConcurrentHashMap<>(); @@ -101,39 +108,47 @@ final class NimbusJwsEncoder implements JwtEncoder { this.jwkSource = jwkSource; } - @Override - public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { + Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { Assert.notNull(headers, "headers cannot be null"); Assert.notNull(claims, "claims cannot be null"); - JWSHeader jwsHeader; + JWK jwk = selectJwk(headers); + headers = addKeyIdentifierHeadersIfNecessary(headers, jwk); + + String jws = serialize(headers, claims, jwk); + + return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims()); + } + + private JWK selectJwk(JoseHeader headers) { + List jwks; try { - jwsHeader = JWS_HEADER_CONVERTER.convert(headers); + JWKSelector jwkSelector = new JWKSelector(createJwkMatcher(headers)); + jwks = this.jwkSource.get(jwkSelector, null); } catch (Exception ex) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to select a JWK signing key -> " + ex.getMessage()), ex); } - JWK jwk = selectJwk(jwsHeader); - if (jwk == null) { + if (jwks.size() > 1) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'")); + } + + if (jwks.isEmpty()) { throw new JwtEncodingException( String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); } - jwsHeader = addKeyIdentifierHeadersIfNecessary(jwsHeader, jwk); - headers = syncKeyIdentifierHeadersIfNecessary(headers, jwsHeader); + return jwks.get(0); + } - JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); + private String serialize(JoseHeader headers, JwtClaimsSet claims, JWK jwk) { + JWSHeader jwsHeader = convert(headers); + JWTClaimsSet jwtClaimsSet = convert(claims); - JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, (key) -> { - try { - return JWS_SIGNER_FACTORY.createJWSSigner(key); - } - catch (JOSEException ex) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Failed to create a JWS Signer -> " + ex.getMessage()), ex); - } - }); + JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, NimbusJwsEncoder::createSigner); SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet); try { @@ -143,71 +158,202 @@ public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingExc throw new JwtEncodingException( String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to sign the JWT -> " + ex.getMessage()), ex); } - String jws = signedJwt.serialize(); + return signedJwt.serialize(); + } + + private static JWKMatcher createJwkMatcher(JoseHeader headers) { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getAlgorithm().getName()); + + if (JWSAlgorithm.Family.RSA.contains(jwsAlgorithm) || JWSAlgorithm.Family.EC.contains(jwsAlgorithm)) { + // @formatter:off + return new JWKMatcher.Builder() + .keyType(KeyType.forAlgorithm(jwsAlgorithm)) + .keyID(headers.getKeyId()) + .keyUses(KeyUse.SIGNATURE, null) + .algorithms(jwsAlgorithm, null) + .x509CertSHA256Thumbprint(Base64URL.from(headers.getX509SHA256Thumbprint())) + .build(); + // @formatter:on + } + else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) { + // @formatter:off + return new JWKMatcher.Builder() + .keyType(KeyType.forAlgorithm(jwsAlgorithm)) + .keyID(headers.getKeyId()) + .privateOnly(true) + .algorithms(jwsAlgorithm, null) + .build(); + // @formatter:on + } - return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims()); + return null; } - private JWK selectJwk(JWSHeader jwsHeader) { - JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); + private static JoseHeader addKeyIdentifierHeadersIfNecessary(JoseHeader headers, JWK jwk) { + // Check if headers have already been added + if (StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(headers.getX509SHA256Thumbprint())) { + return headers; + } + // Check if headers can be added from JWK + if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) { + return headers; + } - List jwks; + JoseHeader.Builder headersBuilder = JoseHeader.from(headers); + if (!StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(jwk.getKeyID())) { + headersBuilder.keyId(jwk.getKeyID()); + } + if (!StringUtils.hasText(headers.getX509SHA256Thumbprint()) && jwk.getX509CertSHA256Thumbprint() != null) { + headersBuilder.x509SHA256Thumbprint(jwk.getX509CertSHA256Thumbprint().toString()); + } + + return headersBuilder.build(); + } + + private static JWSSigner createSigner(JWK jwk) { try { - jwks = this.jwkSource.get(jwkSelector, null); + return JWS_SIGNER_FACTORY.createJWSSigner(jwk); } - catch (KeySourceException ex) { + catch (JOSEException ex) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Failed to select a JWK signing key -> " + ex.getMessage()), ex); + "Failed to create a JWS Signer -> " + ex.getMessage()), ex); } + } - if (jwks.size() > 1) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Found multiple JWK signing keys for algorithm '" + jwsHeader.getAlgorithm().getName() + "'")); + private static JWSHeader convert(JoseHeader headers) { + JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getAlgorithm().getName())); + + if (headers.getJwkSetUri() != null) { + builder.jwkURL(convertAsURI(JoseHeaderNames.JKU, headers.getJwkSetUri())); } - return !jwks.isEmpty() ? jwks.get(0) : null; - } + Map jwk = headers.getJwk(); + if (!CollectionUtils.isEmpty(jwk)) { + try { + builder.jwk(JWK.parse(jwk)); + } + catch (Exception ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header"), ex); + } + } - private static JWSHeader addKeyIdentifierHeadersIfNecessary(JWSHeader jwsHeader, JWK jwk) { - // Check if headers have already been added - if (StringUtils.hasText(jwsHeader.getKeyID()) && jwsHeader.getX509CertSHA256Thumbprint() != null) { - return jwsHeader; + String keyId = headers.getKeyId(); + if (StringUtils.hasText(keyId)) { + builder.keyID(keyId); } - // Check if headers can be added from JWK - if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) { - return jwsHeader; + + if (headers.getX509Uri() != null) { + builder.x509CertURL(convertAsURI(JoseHeaderNames.X5U, headers.getX509Uri())); + } + + List x509CertificateChain = headers.getX509CertificateChain(); + if (!CollectionUtils.isEmpty(x509CertificateChain)) { + List x5cList = new ArrayList<>(); + x509CertificateChain.forEach((x5c) -> x5cList.add(new Base64(x5c))); + if (!x5cList.isEmpty()) { + builder.x509CertChain(x5cList); + } + } + + String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint(); + if (StringUtils.hasText(x509SHA1Thumbprint)) { + builder.x509CertThumbprint(new Base64URL(x509SHA1Thumbprint)); + } + + String x509SHA256Thumbprint = headers.getX509SHA256Thumbprint(); + if (StringUtils.hasText(x509SHA256Thumbprint)) { + builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint)); } - JWSHeader.Builder headerBuilder = new JWSHeader.Builder(jwsHeader); - if (!StringUtils.hasText(jwsHeader.getKeyID()) && StringUtils.hasText(jwk.getKeyID())) { - headerBuilder.keyID(jwk.getKeyID()); + String type = headers.getType(); + if (StringUtils.hasText(type)) { + builder.type(new JOSEObjectType(type)); } - if (jwsHeader.getX509CertSHA256Thumbprint() == null && jwk.getX509CertSHA256Thumbprint() != null) { - headerBuilder.x509CertSHA256Thumbprint(jwk.getX509CertSHA256Thumbprint()); + + String contentType = headers.getContentType(); + if (StringUtils.hasText(contentType)) { + builder.contentType(contentType); + } + + Set critical = headers.getCritical(); + if (!CollectionUtils.isEmpty(critical)) { + builder.criticalParams(critical); + } + + Map customHeaders = new HashMap<>(); + headers.getHeaders().forEach((name, value) -> { + if (!JWSHeader.getRegisteredParameterNames().contains(name)) { + customHeaders.put(name, value); + } + }); + if (!customHeaders.isEmpty()) { + builder.customParams(customHeaders); } - return headerBuilder.build(); + return builder.build(); } - private static JoseHeader syncKeyIdentifierHeadersIfNecessary(JoseHeader joseHeader, JWSHeader jwsHeader) { - String jwsHeaderX509SHA256Thumbprint = null; - if (jwsHeader.getX509CertSHA256Thumbprint() != null) { - jwsHeaderX509SHA256Thumbprint = jwsHeader.getX509CertSHA256Thumbprint().toString(); + private static JWTClaimsSet convert(JwtClaimsSet claims) { + JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder(); + + // NOTE: The value of the 'iss' claim is a String or URL (StringOrURI). + Object issuer = claims.getClaim(JwtClaimNames.ISS); + if (issuer != null) { + builder.issuer(issuer.toString()); + } + + String subject = claims.getSubject(); + if (StringUtils.hasText(subject)) { + builder.subject(subject); + } + + List audience = claims.getAudience(); + if (!CollectionUtils.isEmpty(audience)) { + builder.audience(audience); + } + + Instant expiresAt = claims.getExpiresAt(); + if (expiresAt != null) { + builder.expirationTime(Date.from(expiresAt)); + } + + Instant notBefore = claims.getNotBefore(); + if (notBefore != null) { + builder.notBeforeTime(Date.from(notBefore)); } - if (Objects.equals(joseHeader.getKeyId(), jwsHeader.getKeyID()) - && Objects.equals(joseHeader.getX509SHA256Thumbprint(), jwsHeaderX509SHA256Thumbprint)) { - return joseHeader; + + Instant issuedAt = claims.getIssuedAt(); + if (issuedAt != null) { + builder.issueTime(Date.from(issuedAt)); } - JoseHeader.Builder headerBuilder = JoseHeader.from(joseHeader); - if (!Objects.equals(joseHeader.getKeyId(), jwsHeader.getKeyID())) { - headerBuilder.keyId(jwsHeader.getKeyID()); + String jwtId = claims.getId(); + if (StringUtils.hasText(jwtId)) { + builder.jwtID(jwtId); } - if (!Objects.equals(joseHeader.getX509SHA256Thumbprint(), jwsHeaderX509SHA256Thumbprint)) { - headerBuilder.x509SHA256Thumbprint(jwsHeaderX509SHA256Thumbprint); + + Map customClaims = new HashMap<>(); + claims.getClaims().forEach((name, value) -> { + if (!JWTClaimsSet.getRegisteredNames().contains(name)) { + customClaims.put(name, value); + } + }); + if (!customClaims.isEmpty()) { + customClaims.forEach(builder::claim); } - return headerBuilder.build(); + return builder.build(); + } + + private static URI convertAsURI(String header, URL url) { + try { + return url.toURI(); + } + catch (Exception ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Unable to convert '" + header + "' JOSE header to a URI"), ex); + } } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index aabaa92cd71..1a603b09f21 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -79,7 +79,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter jwkResolver; - private final Map jwtEncoders = new ConcurrentHashMap<>(); + private final Map jwtEncoders = new ConcurrentHashMap<>(); private JwtCustomizer jwtCustomizer = (request, headers, claims) -> { }; @@ -127,7 +127,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(jwsAlgorithm); Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(Duration.ofSeconds(30)); + Instant expiresAt = issuedAt.plus(Duration.ofSeconds(60)); // @formatter:off JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder() @@ -144,7 +144,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { JoseHeader joseHeader = headersBuilder.build(); JwtClaimsSet jwtClaimsSet = claimsBuilder.build(); - JwtEncoder jwsEncoder = this.jwtEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), + NimbusJwsEncoder jwsEncoder = this.jwtEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), (clientRegistrationId) -> { JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); return new NimbusJwsEncoder(jwkSource); @@ -202,7 +202,7 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) { * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} */ @FunctionalInterface - interface JwtCustomizer { + interface JwtCustomizer { /** * Customize the {@link Jwt} headers and/or claims. From 20c0ea5a59f54a7fa1258806b1ec2587d5f84af4 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 8 Apr 2021 06:30:14 -0400 Subject: [PATCH 4/4] Replace JwtCustomizer with JwtClientAuthenticationContext --- ...ientAuthenticationParametersConverter.java | 69 ++++++++++++++----- ...uthenticationParametersConverterTests.java | 6 +- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index 1a603b09f21..4c2d4ac35cd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import java.util.function.Function; import com.nimbusds.jose.jwk.JWK; @@ -60,7 +61,7 @@ * @since 5.5 * @see Converter * @see com.nimbusds.jose.jwk.JWK - * @see JwtCustomizer + * @see JwtClientAuthenticationContext * @see 2.2 * Using JWTs for Client Authentication * @see 4.2 @@ -79,14 +80,14 @@ public final class NimbusJwtClientAuthenticationParametersConverter jwkResolver; - private final Map jwtEncoders = new ConcurrentHashMap<>(); + private final Map jwsEncoders = new ConcurrentHashMap<>(); - private JwtCustomizer jwtCustomizer = (request, headers, claims) -> { + private Consumer> jwtCustomizer = (context) -> { }; /** - * Constructs a {@code NimbusJwtClientAuthenticationCustomizer} using the provided - * parameters. + * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the + * provided parameters. * @param jwkResolver the resolver that provides the {@code com.nimbusds.jose.jwk.JWK} * associated to the {@link ClientRegistration client} */ @@ -139,12 +140,14 @@ public MultiValueMap convert(T authorizationGrantRequest) { .expiresAt(expiresAt); // @formatter:on - this.jwtCustomizer.customize(authorizationGrantRequest, headersBuilder.headers, claimsBuilder.claims); + JwtClientAuthenticationContext context = new JwtClientAuthenticationContext<>(authorizationGrantRequest, + headersBuilder.headers, claimsBuilder.claims); + this.jwtCustomizer.accept(context); JoseHeader joseHeader = headersBuilder.build(); JwtClaimsSet jwtClaimsSet = claimsBuilder.build(); - NimbusJwsEncoder jwsEncoder = this.jwtEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), + NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), (clientRegistrationId) -> { JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); return new NimbusJwsEncoder(jwkSource); @@ -160,12 +163,12 @@ public MultiValueMap convert(T authorizationGrantRequest) { } /** - * Sets the {@link JwtCustomizer} to be provided the opportunity to customize the + * Sets the {@link Consumer} to be provided the opportunity to customize the * {@link Jwt} headers and/or claims. - * @param jwtCustomizer the {@link JwtCustomizer} to be provided the opportunity to + * @param jwtCustomizer the {@link Consumer} to be provided the opportunity to * customize the {@link Jwt} headers and/or claims */ - public void setJwtCustomizer(JwtCustomizer jwtCustomizer) { + public void setJwtCustomizer(Consumer> jwtCustomizer) { Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); this.jwtCustomizer = jwtCustomizer; } @@ -196,21 +199,49 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) { } /** - * Implementations of this interface are provided the opportunity to customize the - * {@link Jwt} headers and/or claims. + * A context that provides access to the {@link Jwt} headers and/or claims allowing + * for customization. * * @param the type of {@link AbstractOAuth2AuthorizationGrantRequest} */ - @FunctionalInterface - interface JwtCustomizer { + public static final class JwtClientAuthenticationContext { + + private final T authorizationGrantRequest; + + private final Map headers; + + private final Map claims; + + private JwtClientAuthenticationContext(T authorizationGrantRequest, Map headers, + Map claims) { + this.authorizationGrantRequest = authorizationGrantRequest; + this.headers = headers; + this.claims = claims; + } /** - * Customize the {@link Jwt} headers and/or claims. - * @param authorizationGrantRequest the authorization grant request - * @param headers the headers - * @param claims the claims + * Returns the authorization grant request. + * @return the authorization grant request */ - void customize(T authorizationGrantRequest, Map headers, Map claims); + public T getAuthorizationGrantRequest() { + return this.authorizationGrantRequest; + } + + /** + * Returns the JOSE header(s). + * @return a {@code Map} of the JOSE header(s) + */ + public Map getHeaders() { + return this.headers; + } + + /** + * Returns the JWT Claims Set. + * @return a {@code Map} of the JWT Claims Set + */ + public Map getClaims() { + return this.claims; + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java index 234b436151d..2102dc96ed0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java @@ -116,8 +116,7 @@ public void convertWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() t given(this.jwkResolver.apply(any())).willReturn(rsaJwk); // Add custom claim - this.converter.setJwtCustomizer( - (authorizationGrantRequest, headers, claims) -> claims.put("custom-claim", "custom-value")); + this.converter.setJwtCustomizer((context) -> context.getClaims().put("custom-claim", "custom-value")); // @formatter:off ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() @@ -155,8 +154,7 @@ public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized() given(this.jwkResolver.apply(any())).willReturn(secretJwk); // Add custom claim - this.converter.setJwtCustomizer( - (authorizationGrantRequest, headers, claims) -> claims.put("custom-claim", "custom-value")); + this.converter.setJwtCustomizer((context) -> context.getClaims().put("custom-claim", "custom-value")); // @formatter:off ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()