Skip to content

Allow in-memory authorized client services to be constructed with a map #5994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,6 @@
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.util.Assert;

import java.util.Base64;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -29,15 +28,16 @@
* {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
*
* @author Joe Grandja
* @author Vedran Pavic
* @since 5.0
* @see OAuth2AuthorizedClientService
* @see OAuth2AuthorizedClient
* @see ClientRegistration
* @see Authentication
*/
public final class InMemoryOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
private final ClientRegistrationRepository clientRegistrationRepository;
private Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();

/**
* Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters.
Expand All @@ -49,23 +49,33 @@ public InMemoryOAuth2AuthorizedClientService(ClientRegistrationRepository client
this.clientRegistrationRepository = clientRegistrationRepository;
}

/**
* Sets the map of authorized clients to use.
* @param authorizedClients the map of authorized clients
*/
public void setAuthorizedClients(Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients) {
Assert.notNull(authorizedClients, "authorizedClients cannot be null");
this.authorizedClients = authorizedClients;
}

@Override
@SuppressWarnings("unchecked")
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
if (registration == null) {
return null;
}
return (T) this.authorizedClients.get(this.getIdentifier(registration, principalName));
return (T) this.authorizedClients.get(OAuth2AuthorizedClientId.create(registration, principalName));
}

@Override
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
Assert.notNull(principal, "principal cannot be null");
this.authorizedClients.put(this.getIdentifier(
authorizedClient.getClientRegistration(), principal.getName()), authorizedClient);
this.authorizedClients.put(OAuth2AuthorizedClientId.create(authorizedClient.getClientRegistration(),
principal.getName()), authorizedClient);
}

@Override
Expand All @@ -74,12 +84,8 @@ public void removeAuthorizedClient(String clientRegistrationId, String principal
Assert.hasText(principalName, "principalName cannot be empty");
ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
if (registration != null) {
this.authorizedClients.remove(this.getIdentifier(registration, principalName));
this.authorizedClients.remove(OAuth2AuthorizedClientId.create(registration, principalName));
}
}

private String getIdentifier(ClientRegistration registration, String principalName) {
String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
return Base64.getEncoder().encodeToString(identifier.getBytes());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.springframework.security.oauth2.client;

import java.util.Base64;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -31,14 +30,15 @@
* {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
*
* @author Rob Winch
* @author Vedran Pavic
* @since 5.1
* @see OAuth2AuthorizedClientService
* @see OAuth2AuthorizedClient
* @see ClientRegistration
* @see Authentication
*/
public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {
private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
private final Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();;
private final ReactiveClientRegistrationRepository clientRegistrationRepository;

/**
Expand All @@ -52,10 +52,12 @@ public InMemoryReactiveOAuth2AuthorizedClientService(ReactiveClientRegistrationR
}

@Override
@SuppressWarnings("unchecked")
public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String clientRegistrationId, String principalName) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
return (Mono<T>) getIdentifier(clientRegistrationId, principalName)
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
.flatMap(identifier -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
}

Expand All @@ -64,7 +66,8 @@ public Mono<Void> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient,
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
Assert.notNull(principal, "principal cannot be null");
return Mono.fromRunnable(() -> {
String identifier = this.getIdentifier(authorizedClient.getClientRegistration(), principal.getName());
OAuth2AuthorizedClientId identifier = OAuth2AuthorizedClientId.create(
authorizedClient.getClientRegistration(), principal.getName());
this.authorizedClients.put(identifier, authorizedClient);
});
}
Expand All @@ -73,18 +76,10 @@ public Mono<Void> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient,
public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
return this.getIdentifier(clientRegistrationId, principalName)
.doOnNext(identifier -> this.authorizedClients.remove(identifier))
.then(Mono.empty());
}

private Mono<String> getIdentifier(String clientRegistrationId, String principalName) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.map(registration -> getIdentifier(registration, principalName));
.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
.doOnNext(this.authorizedClients::remove)
.then(Mono.empty());
}

private String getIdentifier(ClientRegistration registration, String principalName) {
String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
return Base64.getEncoder().encodeToString(identifier.getBytes());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client;

import java.io.Serializable;
import java.util.Objects;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.util.Assert;

/**
* The identifier for {@link OAuth2AuthorizedClient}.
*
* @author Vedran Pavic
* @since 5.2
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientService
*/
public final class OAuth2AuthorizedClientId implements Serializable {

private final String clientRegistrationId;

private final String principalName;

private OAuth2AuthorizedClientId(String clientRegistrationId, String principalName) {
Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null");
Assert.notNull(principalName, "principalName cannot be null");
this.clientRegistrationId = clientRegistrationId;
this.principalName = principalName;
}

/**
* Factory method for creating new {@link OAuth2AuthorizedClientId} using
* {@link ClientRegistration} and principal name.
* @param clientRegistration the client registration
* @param principalName the principal name
* @return the new authorized client id
*/
public static OAuth2AuthorizedClientId create(ClientRegistration clientRegistration,
String principalName) {
return new OAuth2AuthorizedClientId(clientRegistration.getRegistrationId(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should Assert.notNull on clientRegistration and Assert.hasText on principalName. Applying this update would than make Assert.notNull redundant in OAuth2AuthorizedClientId constructor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the parameter ClientRegistration clientRegistration should be changed to String clientRegistrationId, which aligns with OAuth2AuthorizedClientService.loadAuthorizedClient(String clientRegistrationId, String principalName).

I'm also curious on why this factory method is needed? Alternatively, we can make the constructor public. What are your throughts?

principalName);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
OAuth2AuthorizedClientId that = (OAuth2AuthorizedClientId) obj;
return Objects.equals(this.clientRegistrationId, that.clientRegistrationId)
&& Objects.equals(this.principalName, that.principalName);
}

@Override
public int hashCode() {
return Objects.hash(this.clientRegistrationId, this.principalName);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to implement toString as well?

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,11 @@
*/
package org.springframework.security.oauth2.client;

import java.util.Collections;
import java.util.Map;

import org.junit.Test;

import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
Expand All @@ -24,13 +28,17 @@
import org.springframework.security.oauth2.core.OAuth2AccessToken;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
*
* @author Joe Grandja
* @author Vedran Pavic
*/
public class InMemoryOAuth2AuthorizedClientServiceTests {
private String principalName1 = "principal-1";
Expand All @@ -57,6 +65,30 @@ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArg
new InMemoryOAuth2AuthorizedClientService(null);
}

@Test
public void constructorWhenAuthorizedClientsIsNullThenIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClients(null))
.withMessage("authorizedClients cannot be null");
}

@Test
public void constructorWhenAuthorizedClientsIsEmptyMapThenRepositoryUsingSuppliedAuthorizedClients() {
String registrationId = this.registration3.getRegistrationId();

Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
OAuth2AuthorizedClientId.create(this.registration3, this.principalName1),
mock(OAuth2AuthorizedClient.class));
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);

InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
this.clientRegistrationRepository);
authorizedClientService.setAuthorizedClients(authorizedClients);
assertThat((OAuth2AuthorizedClient) authorizedClientService.loadAuthorizedClient(
registrationId, this.principalName1)).isNotNull();
}

@Test(expected = IllegalArgumentException.class)
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.loadAuthorizedClient(null, this.principalName1);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client;

import org.junit.Test;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link OAuth2AuthorizedClientId}.
*
* @author Vedran Pavic
*/
public class OAuth2AuthorizedClientIdTests {

@Test
public void equalsWhenSameRegistrationIdAndPrincipalThenShouldReturnTrue() {
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal");
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal");
assertThat(id1.equals(id2)).isTrue();
}

@Test
public void equalsWhenDifferentRegistrationIdAndSamePrincipalThenShouldReturnFalse() {
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
"test-principal");
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
"test-principal");
assertThat(id1.equals(id2)).isFalse();
}

@Test
public void equalsWhenSameRegistrationIdAndDifferentPrincipalThenShouldReturnFalse() {
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal1");
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal2");
assertThat(id1.equals(id2)).isFalse();
}

@Test
public void hashCodeWhenSameRegistrationIdAndPrincipalThenShouldReturnSame() {
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal");
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal");
assertThat(id1.hashCode()).isEqualTo(id2.hashCode());
}

@Test
public void hashCodeWhenDifferentRegistrationIdAndSamePrincipalThenShouldNotReturnSame() {
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
"test-principal");
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
"test-principal");
assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
}

@Test
public void hashCodeWhenSameRegistrationIdAndDifferentPrincipalThenShouldNotReturnSame() {
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal1");
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
"test-principal2");
assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
}

private static ClientRegistration testClientRegistration(String registrationId) {
return ClientRegistration.withRegistrationId(registrationId).clientId("id").clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
.authorizationUri("http://example.com/authorize").tokenUri("http://example.com/token").build();
}

}