Skip to content

Introduce OAuth2AuthorizationRequest.attributes #6508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -146,14 +146,14 @@ public void configureWhenAuthorizationCodeResponseSuccessThenAuthorizedClientSav
this.spring.register(OAuth2ClientConfig.class).autowire();

// Setup the Authorization Request in the session
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri())
.clientId(this.registration1.getClientId())
.redirectUri("http://localhost/client-1")
.state("state")
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();

AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -473,7 +473,7 @@ private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(ClientRegist
.clientId(registration.getClientId())
.state("state123")
.redirectUri("http://localhost")
.additionalParameters(
.attributes(
Collections.singletonMap(
OAuth2ParameterNames.REGISTRATION_ID,
registration.getRegistrationId()))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -115,16 +115,16 @@ private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String re

String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);

Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());

OAuth2AuthorizationRequest authorizationRequest = builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();

return authorizationRequest;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -161,7 +161,7 @@ private void processAuthorizationResponse(HttpServletRequest request, HttpServle
OAuth2AuthorizationRequest authorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);

String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);

MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -165,7 +165,7 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
if (clientRegistration == null) {
OAuth2Error oauth2Error = new OAuth2Error(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -118,9 +118,8 @@ private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchan
String redirectUriStr = this
.expandRedirectUri(exchange.getRequest(), clientRegistration);

Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID,
clientRegistration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());

OAuth2AuthorizationRequest.Builder builder;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
Expand All @@ -139,7 +138,7 @@ else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizat
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -85,9 +85,9 @@ private <T> Mono<T> oauth2AuthorizationException(String errorCode) {

private Mono<OAuth2AuthorizationCodeAuthenticationToken> authenticationRequest(ServerWebExchange exchange, OAuth2AuthorizationRequest authorizationRequest) {
return Mono.just(authorizationRequest)
.map(OAuth2AuthorizationRequest::getAdditionalParameters)
.flatMap(additionalParams -> {
String id = (String) additionalParams.get(OAuth2ParameterNames.REGISTRATION_ID);
.map(OAuth2AuthorizationRequest::getAttributes)
.flatMap(attributes -> {
String id = (String) attributes.get(OAuth2ParameterNames.REGISTRATION_ID);
if (id == null) {
return oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -105,7 +105,8 @@ public void resolveWhenAuthorizationRequestWithValidClientThenResolves() {
.isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes());
assertThat(authorizationRequest.getState()).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is to fix gh-5940 we should assert that the additional parameters do not contain the registration id (just because it is in the attributes -- the next assertion -- doesn't mean it isn't also in the parameters)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Applied and pushed.

assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID);
assertThat(authorizationRequest.getAttributes())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
Expand All @@ -123,7 +124,7 @@ public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves

OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
assertThat(authorizationRequest.getAttributes())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -402,15 +402,15 @@ public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMat

private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration, String state) {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
.clientId(registration.getClientId())
.redirectUri(expandRedirectUri(request, registration))
.scopes(registration.getScopes())
.state(state)
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -74,7 +74,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest {
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state("state")
.additionalParameters(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId));
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId));

private final MockServerHttpRequest.BaseBuilder<?> request = MockServerHttpRequest.get("/");

Expand All @@ -95,8 +95,8 @@ public void applyWhenAuthorizationRequestEmptyThenOAuth2AuthorizationException()
}

@Test
public void applyWhenAdditionalParametersMissingThenOAuth2AuthorizationException() {
this.authorizationRequest.additionalParameters(Collections.emptyMap());
public void applyWhenAttributesMissingThenOAuth2AuthorizationException() {
this.authorizationRequest.attributes(Collections.emptyMap());
when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build()));

assertThatThrownBy(() -> applyConverter())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,15 @@
*/
package org.springframework.security.oauth2.core.endpoint;

import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
Expand All @@ -25,15 +34,6 @@
import java.util.Set;
import java.util.stream.Collectors;

import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

/**
* A representation of an OAuth 2.0 Authorization Request
* for the authorization code grant type or implicit grant type.
Expand All @@ -56,6 +56,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private String state;
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Map<String, Object> attributes;

private OAuth2AuthorizationRequest() {
}
Expand Down Expand Up @@ -132,6 +133,29 @@ public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}

/**
* Returns the attributes associated to the request.
*
* @since 5.2
* @return a {@code Map} of the attributes associated to the request
*/
public Map<String, Object> getAttributes() {
return this.attributes;
}

/**
* Returns the value of an attribute associated to the request, or {@code null} if not available.
*
* @since 5.2
* @param name the name of the attribute
* @param <T> the type of the attribute
* @return the value of the attribute associated to the request
*/
@SuppressWarnings("unchecked")
public <T> T getAttribute(String name) {
return (T) this.getAttributes().get(name);
}

/**
* Returns the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
Expand Down Expand Up @@ -181,7 +205,8 @@ public static Builder from(OAuth2AuthorizationRequest authorizationRequest) {
.redirectUri(authorizationRequest.getRedirectUri())
.scopes(authorizationRequest.getScopes())
.state(authorizationRequest.getState())
.additionalParameters(authorizationRequest.getAdditionalParameters());
.additionalParameters(authorizationRequest.getAdditionalParameters())
.attributes(authorizationRequest.getAttributes());
}

/**
Expand All @@ -197,6 +222,7 @@ public static class Builder {
private String state;
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Map<String, Object> attributes;

private Builder(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
Expand Down Expand Up @@ -288,6 +314,18 @@ public Builder additionalParameters(Map<String, Object> additionalParameters) {
return this;
}

/**
* Sets the attributes associated to the request.
*
* @since 5.2
* @param attributes the attributes associated to the request
* @return the {@link Builder}
*/
public Builder attributes(Map<String, Object> attributes) {
this.attributes = attributes;
return this;
}

/**
* Sets the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
Expand Down Expand Up @@ -332,6 +370,9 @@ public OAuth2AuthorizationRequest build() {
authorizationRequest.authorizationRequestUri =
StringUtils.hasText(this.authorizationRequestUri) ?
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
authorizationRequest.attributes = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.attributes) ?
Collections.emptyMap() : new LinkedHashMap<>(this.attributes));

return authorizationRequest;
}
Expand All @@ -351,9 +392,7 @@ private String buildAuthorizationRequestUri() {
parameters.set(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
}
if (!CollectionUtils.isEmpty(this.additionalParameters)) {
this.additionalParameters.entrySet().stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.REGISTRATION_ID))
.forEach(e -> parameters.set(e.getKey(), e.getValue().toString()));
this.additionalParameters.forEach((k, v) -> parameters.set(k, v.toString()));
}

return UriComponentsBuilder.fromHttpUrl(this.authorizationUri)
Expand Down
Loading