Skip to content

Simplify customizing OAuth2AuthorizationRequest #7748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,6 +41,7 @@
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;

/**
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
Expand All @@ -66,6 +67,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
private final AntPathRequestMatcher authorizationRequestMatcher;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};

/**
* Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
Expand Down Expand Up @@ -98,6 +100,18 @@ public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String reg
return resolve(request, registrationId, redirectUriAction);
}

/**
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
* allowing for further customizations.
*
* @since 5.3
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
*/
public void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
}

private String getAction(HttpServletRequest request, String defaultAction) {
String action = request.getParameter("action");
if (action == null) {
Expand Down Expand Up @@ -144,16 +158,17 @@ private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String re

String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);

OAuth2AuthorizationRequest authorizationRequest = builder
builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes)
.build();
.attributes(attributes);

this.authorizationRequestCustomizer.accept(builder);

return authorizationRequest;
return builder.build();
}

private String resolveRegistrationId(HttpServletRequest request) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,6 +46,7 @@
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;

/**
* The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}.
Expand Down Expand Up @@ -81,6 +82,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver

private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);

private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};

/**
* Creates a new instance
* @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
Expand Down Expand Up @@ -121,6 +124,18 @@ public Mono<OAuth2AuthorizationRequest> resolve(ServerWebExchange exchange,
.map(clientRegistration -> authorizationRequest(exchange, clientRegistration));
}

/**
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
* allowing for further customizations.
*
* @since 5.3
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
*/
public final void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
}

private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
Expand Down Expand Up @@ -155,13 +170,17 @@ private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchan
"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
}
return builder
builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes)
.build();
.attributes(attributes);

this.authorizationRequestCustomizer.accept(builder);

return builder.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -70,8 +71,8 @@ public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception
this.authorizationRequestBuilder
.scopes(null)
.state(null)
.additionalParameters(null)
.attributes(null)
.additionalParameters(Collections.emptyMap())
.attributes(Collections.emptyMap())
.build();
String expectedJson = asJson(authorizationRequest);
String json = this.mapper.writeValueAsString(authorizationRequest);
Expand Down Expand Up @@ -118,8 +119,8 @@ public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Excep
this.authorizationRequestBuilder
.scopes(null)
.state(null)
.additionalParameters(null)
.attributes(null)
.additionalParameters(Collections.emptyMap())
.attributes(Collections.emptyMap())
.build();
String json = asJson(expectedAuthorizationRequest);
OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,7 +31,9 @@
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;

import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.entry;

/**
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
Expand Down Expand Up @@ -81,6 +83,12 @@ public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgu
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
String requestUri = "/path";
Expand Down Expand Up @@ -414,6 +422,76 @@ public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() {
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
}

// gh-7696
@Test
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);

this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));

OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id");
}

@Test
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);

this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.authorizationRequestUri(uriBuilder -> {
uriBuilder.queryParam("param1", "value1");
return uriBuilder.build();
})
);

OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"param1=value1");
}

@Test
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);

this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.parameters(params -> {
params.put("appid", params.get("client_id"));
params.remove("client_id");
})
);

OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"appid=client-id");
}

private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
.redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,6 +37,7 @@
import reactor.core.publisher.Mono;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.catchThrowableOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
Expand All @@ -59,6 +60,12 @@ public void setup() {
this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository);
}

@Test
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void resolveWhenNotMatchThenNull() {
assertThat(resolve("/")).isNull();
Expand Down Expand Up @@ -139,6 +146,79 @@ public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() {
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
}

// gh-7696
@Test
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
Mono.just(TestClientRegistrations.clientRegistration()
.scope(OidcScopes.OPENID)
.build()));

this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));

OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");

assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=/login/oauth2/code/registration-id");
}

@Test
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
Mono.just(TestClientRegistrations.clientRegistration()
.scope(OidcScopes.OPENID)
.build()));

this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.authorizationRequestUri(uriBuilder -> {
uriBuilder.queryParam("param1", "value1");
return uriBuilder.build();
})
);

OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");

assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=/login/oauth2/code/registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"param1=value1");
}

@Test
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
Mono.just(TestClientRegistrations.clientRegistration()
.scope(OidcScopes.OPENID)
.build()));

this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.parameters(params -> {
params.put("appid", params.get("client_id"));
params.remove("client_id");
})
);

OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");

assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=/login/oauth2/code/registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"appid=client-id");
}

private OAuth2AuthorizationRequest resolve(String path) {
ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path));
return this.resolver.resolve(exchange).block();
Expand Down
Loading