Skip to content

Commit 1433393

Browse files
committed
Provide more flexibility on when to display consent page
1 parent d151568 commit 1433393

File tree

3 files changed

+188
-22
lines changed

3 files changed

+188
-22
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java

+45
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.util.function.Consumer;
2222

2323
import org.springframework.lang.Nullable;
24+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
25+
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
2426
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
2527
import org.springframework.util.Assert;
2628

@@ -63,6 +65,27 @@ public RegisteredClient getRegisteredClient() {
6365
return get(RegisteredClient.class);
6466
}
6567

68+
/**
69+
* Returns the {@link OAuth2AuthorizationRequest oauth2 authorization request}.
70+
*
71+
* @return the {@link OAuth2AuthorizationRequest}
72+
*/
73+
@Nullable
74+
public OAuth2AuthorizationRequest getOAuth2AuthorizationRequest() {
75+
return get(OAuth2AuthorizationRequest.class);
76+
}
77+
78+
/**
79+
* Returns the {@link OAuth2AuthorizationConsent oauth2 authorization consent}.
80+
*
81+
* @return the {@link OAuth2AuthorizationConsent}
82+
*/
83+
@Nullable
84+
public OAuth2AuthorizationConsent getOAuth2AuthorizationConsent() {
85+
return get(OAuth2AuthorizationConsent.class);
86+
}
87+
88+
6689
/**
6790
* Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationCodeRequestAuthenticationToken}.
6891
*
@@ -92,6 +115,28 @@ public Builder registeredClient(RegisteredClient registeredClient) {
92115
return put(RegisteredClient.class, registeredClient);
93116
}
94117

118+
/**
119+
* Sets the {@link OAuth2AuthorizationRequest oauth2 authorization request}.
120+
*
121+
* @param authorizationRequest the {@link OAuth2AuthorizationRequest}
122+
* @return the {@link Builder} for further configuration
123+
* @since 1.3.0
124+
*/
125+
public Builder authorizationRequest(OAuth2AuthorizationRequest authorizationRequest) {
126+
return put(OAuth2AuthorizationRequest.class, authorizationRequest);
127+
}
128+
129+
/**
130+
* Sets the {@link OAuth2AuthorizationConsent oauth2 authorization consent}.
131+
*
132+
* @param authorizationConsent the {@link OAuth2AuthorizationConsent}
133+
* @return the {@link Builder} for further configuration
134+
* @since 1.3.0
135+
*/
136+
public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) {
137+
return put(OAuth2AuthorizationConsent.class, authorizationConsent);
138+
}
139+
95140
/**
96141
* Builds a new {@link OAuth2AuthorizationCodeRequestAuthenticationContext}.
97142
*

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

+58-22
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Base64;
2020
import java.util.Set;
2121
import java.util.function.Consumer;
22+
import java.util.function.Predicate;
2223

2324
import org.apache.commons.logging.Log;
2425
import org.apache.commons.logging.LogFactory;
@@ -80,6 +81,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
8081
private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator();
8182
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator =
8283
new OAuth2AuthorizationCodeRequestAuthenticationValidator();
84+
private Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent;
8385

8486
/**
8587
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters.
@@ -96,6 +98,7 @@ public OAuth2AuthorizationCodeRequestAuthenticationProvider(RegisteredClientRepo
9698
this.registeredClientRepository = registeredClientRepository;
9799
this.authorizationService = authorizationService;
98100
this.authorizationConsentService = authorizationConsentService;
101+
this.requiresAuthorizationConsent = this::requireAuthorizationConsent;
99102
}
100103

101104
@Override
@@ -171,7 +174,19 @@ public Authentication authenticate(Authentication authentication) throws Authent
171174
OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService.findById(
172175
registeredClient.getId(), principal.getName());
173176

174-
if (requireAuthorizationConsent(registeredClient, authorizationRequest, currentAuthorizationConsent)) {
177+
OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder =
178+
OAuth2AuthorizationCodeRequestAuthenticationContext.with(authorizationCodeRequestAuthentication)
179+
.registeredClient(registeredClient)
180+
.authorizationRequest(authorizationRequest);
181+
182+
if (currentAuthorizationConsent != null) {
183+
authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent);
184+
}
185+
186+
OAuth2AuthorizationCodeRequestAuthenticationContext contextWithAuthorizationRequestAndAuthorizationConsent =
187+
authenticationContextBuilder.build();
188+
189+
if (requiresAuthorizationConsent.test(contextWithAuthorizationRequestAndAuthorizationConsent)) {
175190
String state = DEFAULT_STATE_GENERATOR.generateKey();
176191
OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest)
177192
.attribute(OAuth2ParameterNames.STATE, state)
@@ -264,7 +279,48 @@ public void setAuthenticationValidator(Consumer<OAuth2AuthorizationCodeRequestAu
264279
this.authenticationValidator = authenticationValidator;
265280
}
266281

267-
private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal,
282+
/**
283+
* Sets the {@link Predicate} used to determine if authorization consent is required.
284+
*
285+
* <p>
286+
* The {@link OAuth2AuthorizationCodeRequestAuthenticationContext} gives the predicate access to the {@link OAuth2AuthorizationCodeRequestAuthenticationToken},
287+
* as well as, the following context attributes:
288+
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getRegisteredClient()} containing {@link RegisteredClient} used to make the request.
289+
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationRequest()} containing {@link OAuth2AuthorizationRequest}.
290+
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationConsent()} containing {@link OAuth2AuthorizationConsent} granted in the request.
291+
*
292+
* @param requiresAuthorizationConsent the {@link Predicate} that determines if authorization consent is required.
293+
* @since 1.3.0
294+
*/
295+
public void setRequiresAuthorizationConsent(Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent) {
296+
Assert.notNull(requiresAuthorizationConsent, "requiresAuthorizationConsent cannot be null");
297+
this.requiresAuthorizationConsent = requiresAuthorizationConsent;
298+
}
299+
300+
private boolean requireAuthorizationConsent(OAuth2AuthorizationCodeRequestAuthenticationContext context) {
301+
RegisteredClient registeredClient = context.getRegisteredClient();
302+
if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) {
303+
return false;
304+
}
305+
306+
OAuth2AuthorizationRequest authorizationRequest = context.getOAuth2AuthorizationRequest();
307+
// 'openid' scope does not require consent
308+
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) &&
309+
authorizationRequest.getScopes().size() == 1) {
310+
return false;
311+
}
312+
313+
OAuth2AuthorizationConsent authorizationConsent = context.getOAuth2AuthorizationConsent();
314+
if (authorizationConsent != null &&
315+
authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) {
316+
return false;
317+
}
318+
319+
return true;
320+
}
321+
322+
private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient,
323+
Authentication principal,
268324
OAuth2AuthorizationRequest authorizationRequest) {
269325
return OAuth2Authorization.withRegisteredClient(registeredClient)
270326
.principalName(principal.getName())
@@ -295,26 +351,6 @@ private static OAuth2TokenContext createAuthorizationCodeTokenContext(
295351
return tokenContextBuilder.build();
296352
}
297353

298-
private static boolean requireAuthorizationConsent(RegisteredClient registeredClient,
299-
OAuth2AuthorizationRequest authorizationRequest, OAuth2AuthorizationConsent authorizationConsent) {
300-
301-
if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) {
302-
return false;
303-
}
304-
// 'openid' scope does not require consent
305-
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) &&
306-
authorizationRequest.getScopes().size() == 1) {
307-
return false;
308-
}
309-
310-
if (authorizationConsent != null &&
311-
authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) {
312-
return false;
313-
}
314-
315-
return true;
316-
}
317-
318354
private static boolean isPrincipalAuthenticated(Authentication principal) {
319355
return principal != null &&
320356
!AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

+85
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Map;
2222
import java.util.Set;
2323
import java.util.function.Consumer;
24+
import java.util.function.Predicate;
2425

2526
import org.junit.jupiter.api.BeforeEach;
2627
import org.junit.jupiter.api.Test;
@@ -72,6 +73,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
7273
private OAuth2AuthorizationConsentService authorizationConsentService;
7374
private OAuth2AuthorizationCodeRequestAuthenticationProvider authenticationProvider;
7475
private TestingAuthenticationToken principal;
76+
private Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent;
7577

7678
@BeforeEach
7779
public void setUp() {
@@ -129,6 +131,13 @@ public void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException(
129131
.hasMessage("authenticationValidator cannot be null");
130132
}
131133

134+
@Test
135+
public void setRequiresAuthorizationConsentWhenNullThenThrowIllegalArgumentException() {
136+
assertThatThrownBy(() -> this.authenticationProvider.setRequiresAuthorizationConsent(null))
137+
.isInstanceOf(IllegalArgumentException.class)
138+
.hasMessage("requiresAuthorizationConsent cannot be null");
139+
}
140+
132141
@Test
133142
public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
134143
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -443,6 +452,82 @@ public void authenticateWhenRequireAuthorizationConsentThenReturnAuthorizationCo
443452
assertThat(authenticationResult.isAuthenticated()).isTrue();
444453
}
445454

455+
@Test
456+
public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateTrueThenReturnAuthorizationConsent() {
457+
this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> true);
458+
459+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
460+
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
461+
.build();
462+
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
463+
.thenReturn(registeredClient);
464+
465+
String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[0];
466+
OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
467+
new OAuth2AuthorizationCodeRequestAuthenticationToken(
468+
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
469+
redirectUri, STATE, registeredClient.getScopes(), null);
470+
471+
OAuth2AuthorizationConsentAuthenticationToken authenticationResult =
472+
(OAuth2AuthorizationConsentAuthenticationToken) this.authenticationProvider.authenticate(authentication);
473+
474+
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
475+
verify(this.authorizationService).save(authorizationCaptor.capture());
476+
OAuth2Authorization authorization = authorizationCaptor.getValue();
477+
478+
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
479+
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
480+
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
481+
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(authentication.getAuthorizationUri());
482+
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
483+
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(authentication.getRedirectUri());
484+
assertThat(authorizationRequest.getScopes()).isEqualTo(authentication.getScopes());
485+
assertThat(authorizationRequest.getState()).isEqualTo(authentication.getState());
486+
assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(authentication.getAdditionalParameters());
487+
488+
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
489+
assertThat(authorization.getPrincipalName()).isEqualTo(this.principal.getName());
490+
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
491+
assertThat(authorization.<Authentication>getAttribute(Principal.class.getName())).isEqualTo(this.principal);
492+
String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
493+
assertThat(state).isNotNull();
494+
assertThat(state).isNotEqualTo(authentication.getState());
495+
496+
assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId());
497+
assertThat(authenticationResult.getPrincipal()).isEqualTo(this.principal);
498+
assertThat(authenticationResult.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri());
499+
assertThat(authenticationResult.getScopes()).isEmpty();
500+
assertThat(authenticationResult.getState()).isEqualTo(state);
501+
assertThat(authenticationResult.isAuthenticated()).isTrue();
502+
}
503+
504+
@Test
505+
public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateFalseThenAuthorizationConsentNotRequired() {
506+
this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> false);
507+
508+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
509+
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
510+
.scopes(scopes -> {
511+
scopes.clear();
512+
scopes.add(OidcScopes.OPENID);
513+
scopes.add(OidcScopes.EMAIL);
514+
})
515+
.build();
516+
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
517+
.thenReturn(registeredClient);
518+
519+
String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[1];
520+
OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
521+
new OAuth2AuthorizationCodeRequestAuthenticationToken(
522+
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
523+
redirectUri, STATE, registeredClient.getScopes(), null);
524+
525+
OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult =
526+
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication);
527+
528+
assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult);
529+
}
530+
446531
@Test
447532
public void authenticateWhenRequireAuthorizationConsentAndOnlyOpenidScopeRequestedThenAuthorizationConsentNotRequired() {
448533
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()

0 commit comments

Comments
 (0)