Skip to content

Commit 938dbbf

Browse files
committed
Add OAuth2AuthorizationRequestResolver.resolve(HttpServletRequest,String)
Previously there was a tangle between DefaultOAuth2AuthorizationRequestResolver and OAuth2AuthorizationRequestRedirectFilter with AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME This commit adds a new method that can be used for resolving the OAuth2AuthorizationRequest when the client registration id is known. Issue: gh-4911
1 parent 06df562 commit 938dbbf

File tree

7 files changed

+87
-100
lines changed

7 files changed

+87
-100
lines changed

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

+5-10
Original file line numberDiff line numberDiff line change
@@ -192,21 +192,16 @@ public void configureWhenRequestCacheProvidedAndClientAuthorizationRequiredExcep
192192
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
193193
// Override default resolver
194194
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
195-
authorizationRequestResolver = request -> {
196-
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
197-
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
198-
additionalParameters.put("param1", "value1");
199-
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
200-
.additionalParameters(additionalParameters)
201-
.build();
202-
};
195+
authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
196+
when(authorizationRequestResolver.resolve(any())).thenAnswer(invocation -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
203197

204198
this.spring.register(OAuth2ClientConfig.class).autowire();
205199

206-
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
200+
this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
207201
.andExpect(status().is3xxRedirection())
208202
.andReturn();
209-
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1&param1=value1");
203+
204+
verify(authorizationRequestResolver).resolve(any());
210205
}
211206

212207
@EnableWebSecurity

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

+16-16
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
4545
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
4646
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
47-
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
4847
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
4948
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
5049
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
@@ -78,6 +77,9 @@
7877
import java.util.Map;
7978

8079
import static org.assertj.core.api.Assertions.assertThat;
80+
import static org.mockito.ArgumentMatchers.any;
81+
import static org.mockito.Mockito.mock;
82+
import static org.mockito.Mockito.when;
8183

8284
/**
8385
* Tests for {@link OAuth2LoginConfigurer}.
@@ -236,14 +238,23 @@ public void oauth2LoginConfigLoginProcessingUrl() throws Exception {
236238
@Test
237239
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
238240
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
241+
OAuth2AuthorizationRequestResolver resolver = this.context.getBean(
242+
OAuth2LoginConfigCustomAuthorizationRequestResolver.class).resolver;
243+
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest.authorizationCode()
244+
.authorizationUri("https://accounts.google.com/authorize")
245+
.clientId("client-id")
246+
.state("adsfa")
247+
.authorizationRequestUri("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1")
248+
.build();
249+
when(resolver.resolve(any())).thenReturn(result);
239250

240251
String requestUri = "/oauth2/authorization/google";
241252
this.request = new MockHttpServletRequest("GET", requestUri);
242253
this.request.setServletPath(requestUri);
243254

244255
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
245256

246-
assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
257+
assertThat(this.response.getRedirectedUrl()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
247258
}
248259

249260
// gh-5347
@@ -492,28 +503,17 @@ static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonW
492503
private ClientRegistrationRepository clientRegistrationRepository =
493504
new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION);
494505

506+
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
507+
495508
@Override
496509
protected void configure(HttpSecurity http) throws Exception {
497510
http
498511
.oauth2Login()
499512
.clientRegistrationRepository(this.clientRegistrationRepository)
500513
.authorizationEndpoint()
501-
.authorizationRequestResolver(this.getAuthorizationRequestResolver());
514+
.authorizationRequestResolver(this.resolver);
502515
super.configure(http);
503516
}
504-
505-
private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
506-
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
507-
new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
508-
return request -> {
509-
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
510-
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
511-
additionalParameters.put("custom-param1", "custom-value1");
512-
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
513-
.additionalParameters(additionalParameters)
514-
.build();
515-
};
516-
}
517517
}
518518

519519
@EnableWebSecurity

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java

+23-34
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
1919
import org.springframework.security.crypto.keygen.StringKeyGenerator;
20-
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2120
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2221
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
2322
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -33,8 +32,6 @@
3332
import java.util.HashMap;
3433
import java.util.Map;
3534

36-
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
37-
3835
/**
3936
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
4037
* resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
@@ -45,6 +42,7 @@
4542
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
4643
*
4744
* @author Joe Grandja
45+
* @author Rob Winch
4846
* @since 5.1
4947
* @see OAuth2AuthorizationRequestResolver
5048
* @see OAuth2AuthorizationRequestRedirectFilter
@@ -73,6 +71,28 @@ public DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository cl
7371
@Override
7472
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
7573
String registrationId = this.resolveRegistrationId(request);
74+
String redirectUriAction = getAction(request, "login");
75+
return resolve(request, registrationId, redirectUriAction);
76+
}
77+
78+
@Override
79+
public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId) {
80+
if (registrationId == null) {
81+
return null;
82+
}
83+
String redirectUriAction = getAction(request, "authorize");
84+
return resolve(request, registrationId, redirectUriAction);
85+
}
86+
87+
private String getAction(HttpServletRequest request, String defaultAction) {
88+
String action = request.getParameter("action");
89+
if (action == null) {
90+
return defaultAction;
91+
}
92+
return action;
93+
}
94+
95+
private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId, String redirectUriAction) {
7696
if (registrationId == null) {
7797
return null;
7898
}
@@ -93,7 +113,6 @@ public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
93113
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
94114
}
95115

96-
String redirectUriAction = this.resolveRedirectUriAction(request, clientRegistration);
97116
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
98117

99118
Map<String, Object> additionalParameters = new HashMap<>();
@@ -112,43 +131,13 @@ public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
112131
}
113132

114133
private String resolveRegistrationId(HttpServletRequest request) {
115-
// Check for ClientAuthorizationRequiredException which may have been set
116-
// in the request by OAuth2AuthorizationRequestRedirectFilter
117-
ClientAuthorizationRequiredException authzEx =
118-
(ClientAuthorizationRequiredException) request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
119-
if (authzEx != null) {
120-
return authzEx.getClientRegistrationId();
121-
}
122134
if (this.authorizationRequestMatcher.matches(request)) {
123135
return this.authorizationRequestMatcher
124136
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
125137
}
126138
return null;
127139
}
128140

129-
private String resolveRedirectUriAction(HttpServletRequest request, ClientRegistration clientRegistration) {
130-
String action = null;
131-
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
132-
String loginAction = "login";
133-
String authorizeAction = "authorize";
134-
String actionParameter = request.getParameter("action");
135-
if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
136-
// Check for ClientAuthorizationRequiredException which may have been set
137-
// in the request by OAuth2AuthorizationRequestRedirectFilter
138-
action = authorizeAction;
139-
} else if (actionParameter == null) {
140-
action = loginAction; // Default
141-
} else {
142-
if (actionParameter.equalsIgnoreCase(loginAction)) {
143-
action = loginAction;
144-
} else {
145-
action = authorizeAction;
146-
}
147-
}
148-
}
149-
return action;
150-
}
151-
152141
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
153142
// Supported URI variables -> baseUrl, action, registrationId
154143
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java

+1-6
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
7979
* The default base {@code URI} used for authorization requests.
8080
*/
8181
public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
82-
static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
83-
ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
8482
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
8583
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
8684
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
@@ -169,17 +167,14 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
169167
.getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain);
170168
if (authzEx != null) {
171169
try {
172-
request.setAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME, authzEx);
173-
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
170+
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request, authzEx.getClientRegistrationId());
174171
if (authorizationRequest == null) {
175172
throw authzEx;
176173
}
177174
this.sendRedirectForAuthorization(request, response, authorizationRequest);
178175
this.requestCache.saveRequest(request, response);
179176
} catch (Exception failed) {
180177
this.unsuccessfulRedirectForAuthorization(request, response, failed);
181-
} finally {
182-
request.removeAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
183178
}
184179
return;
185180
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java

+11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
2626
*
2727
* @author Joe Grandja
28+
* @author Rob Winch
2829
* @since 5.1
2930
* @see OAuth2AuthorizationRequest
3031
* @see OAuth2AuthorizationRequestRedirectFilter
@@ -40,4 +41,14 @@ public interface OAuth2AuthorizationRequestResolver {
4041
*/
4142
OAuth2AuthorizationRequest resolve(HttpServletRequest request);
4243

44+
/**
45+
* Returns the {@link OAuth2AuthorizationRequest} resolved from
46+
* the provided {@code HttpServletRequest} or {@code null} if not available.
47+
*
48+
* @param request the {@code HttpServletRequest}
49+
* @param clientRegistrationId the clientRegistrationId to use
50+
* @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available
51+
*/
52+
OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId);
53+
4354
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java

+5-10
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.junit.Before;
1919
import org.junit.Test;
2020
import org.springframework.mock.web.MockHttpServletRequest;
21-
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2221
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2322
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
2423
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
@@ -28,7 +27,9 @@
2827
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
2928
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
3029

31-
import static org.assertj.core.api.Assertions.*;
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
32+
import static org.assertj.core.api.Assertions.entry;
3233

3334
/**
3435
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@@ -139,11 +140,8 @@ public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves
139140
String requestUri = "/path";
140141
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
141142
request.setServletPath(requestUri);
142-
request.setAttribute(
143-
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
144-
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
145143

146-
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
144+
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
147145
assertThat(authorizationRequest).isNotNull();
148146
assertThat(authorizationRequest.getAdditionalParameters())
149147
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
@@ -213,11 +211,8 @@ public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirect
213211
String requestUri = "/path";
214212
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
215213
request.setServletPath(requestUri);
216-
request.setAttribute(
217-
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
218-
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
219214

220-
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
215+
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
221216
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
222217
}
223218

0 commit comments

Comments
 (0)