Skip to content

Commit 47cb4eb

Browse files
committed
enable customization of headers in AbstractWebClientReactiveOAuth2AccessTokenResponseClient
adds the possibility to customize the headers of the access token request, similarly to what is done in the AbstractOAuth2AuthorizationGrantRequestEntityConverter Closes spring-projectsgh-10130
1 parent d5c953b commit 47cb4eb

File tree

5 files changed

+326
-6
lines changed

5 files changed

+326
-6
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import reactor.core.publisher.Mono;
2626

27+
import org.springframework.core.convert.converter.Converter;
2728
import org.springframework.http.HttpHeaders;
2829
import org.springframework.http.MediaType;
2930
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -65,6 +66,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
6566

6667
private WebClient webClient = WebClient.builder().build();
6768

69+
private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
70+
6871
AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
6972
}
7073

@@ -74,7 +77,12 @@ public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
7477
// @formatter:off
7578
return Mono.defer(() -> this.webClient.post()
7679
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
77-
.headers((headers) -> populateTokenRequestHeaders(grantRequest, headers))
80+
.headers((headers) -> {
81+
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
82+
if (headersToAdd != null) {
83+
headers.addAll(headersToAdd);
84+
}
85+
})
7886
.body(createTokenRequestBody(grantRequest))
7987
.exchange()
8088
.flatMap((response) -> readTokenResponse(grantRequest, response))
@@ -92,9 +100,10 @@ public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
92100
/**
93101
* Populates the headers for the token request.
94102
* @param grantRequest the grant request
95-
* @param headers the headers to populate
103+
* @return the headers populated for the token request
96104
*/
97-
private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) {
105+
private HttpHeaders populateTokenRequestHeaders(T grantRequest) {
106+
HttpHeaders headers = new HttpHeaders();
98107
ClientRegistration clientRegistration = clientRegistration(grantRequest);
99108
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
100109
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
@@ -104,6 +113,7 @@ private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) {
104113
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
105114
headers.setBasicAuth(clientId, clientSecret);
106115
}
116+
return headers;
107117
}
108118

109119
private static String encodeClientCredential(String clientCredential) {
@@ -230,4 +240,55 @@ public void setWebClient(WebClient webClient) {
230240
this.webClient = webClient;
231241
}
232242

243+
/**
244+
* Returns the {@link Converter} used for converting the
245+
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
246+
* used in the OAuth 2.0 Access Token Request headers.
247+
* @return the {@link Converter} used for converting the
248+
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
249+
*/
250+
final Converter<T, HttpHeaders> getHeadersConverter() {
251+
return this.headersConverter;
252+
}
253+
254+
/**
255+
* Sets the {@link Converter} used for converting the
256+
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
257+
* used in the OAuth 2.0 Access Token Request headers.
258+
* @param headersConverter the {@link Converter} used for converting the
259+
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
260+
* @since 5.6
261+
*/
262+
public final void setHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
263+
Assert.notNull(headersConverter, "headersConverter cannot be null");
264+
this.headersConverter = headersConverter;
265+
}
266+
267+
/**
268+
* Add (compose) the provided {@code headersConverter} to the current
269+
* {@link Converter} used for converting the
270+
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
271+
* used in the OAuth 2.0 Access Token Request headers.
272+
* @param headersConverter the {@link Converter} to add (compose) to the current
273+
* {@link Converter} used for converting the
274+
* {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link HttpHeaders}
275+
* @since 5.6
276+
*/
277+
public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
278+
Assert.notNull(headersConverter, "headersConverter cannot be null");
279+
Converter<T, HttpHeaders> currentHeadersConverter = this.headersConverter;
280+
this.headersConverter = (authorizationGrantRequest) -> {
281+
// Append headers using a Composite Converter
282+
HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest);
283+
if (headers == null) {
284+
headers = new HttpHeaders();
285+
}
286+
HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest);
287+
if (headersToAdd != null) {
288+
headers.addAll(headersToAdd);
289+
}
290+
return headers;
291+
};
292+
}
293+
233294
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,15 +17,18 @@
1717
package org.springframework.security.oauth2.client.endpoint;
1818

1919
import java.time.Instant;
20+
import java.util.Collections;
2021
import java.util.HashMap;
2122
import java.util.Map;
2223

2324
import okhttp3.mockwebserver.MockResponse;
2425
import okhttp3.mockwebserver.MockWebServer;
26+
import okhttp3.mockwebserver.RecordedRequest;
2527
import org.junit.jupiter.api.AfterEach;
2628
import org.junit.jupiter.api.BeforeEach;
2729
import org.junit.jupiter.api.Test;
2830

31+
import org.springframework.core.convert.converter.Converter;
2932
import org.springframework.http.HttpHeaders;
3033
import org.springframework.http.HttpStatus;
3134
import org.springframework.http.MediaType;
@@ -340,4 +343,65 @@ private OAuth2AuthorizationCodeGrantRequest pkceAuthorizationCodeGrantRequest()
340343
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
341344
}
342345

346+
@Test
347+
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
348+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
349+
.withMessage("headersConverter cannot be null");
350+
}
351+
352+
@Test
353+
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
354+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
355+
.withMessage("headersConverter cannot be null");
356+
}
357+
358+
@Test
359+
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
360+
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
361+
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
362+
final HttpHeaders headers = new HttpHeaders();
363+
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
364+
given(addedHeadersConverter.convert(request)).willReturn(headers);
365+
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
366+
// @formatter:off
367+
String accessTokenSuccessResponse = "{\n"
368+
+ " \"access_token\": \"access-token-1234\",\n"
369+
+ " \"token_type\": \"bearer\",\n"
370+
+ " \"expires_in\": \"3600\",\n"
371+
+ " \"scope\": \"openid profile\"\n"
372+
+ "}\n";
373+
// @formatter:on
374+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
375+
this.tokenResponseClient.getTokenResponse(request).block();
376+
verify(addedHeadersConverter).convert(request);
377+
RecordedRequest actualRequest = this.server.takeRequest();
378+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
379+
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
380+
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
381+
}
382+
383+
@Test
384+
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
385+
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
386+
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
387+
final HttpHeaders headers = new HttpHeaders();
388+
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
389+
given(headersConverter.convert(request)).willReturn(headers);
390+
this.tokenResponseClient.setHeadersConverter(headersConverter);
391+
// @formatter:off
392+
String accessTokenSuccessResponse = "{\n"
393+
+ " \"access_token\": \"access-token-1234\",\n"
394+
+ " \"token_type\": \"bearer\",\n"
395+
+ " \"expires_in\": \"3600\",\n"
396+
+ " \"scope\": \"openid profile\"\n"
397+
+ "}\n";
398+
// @formatter:on
399+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
400+
this.tokenResponseClient.getTokenResponse(request).block();
401+
verify(headersConverter).convert(request);
402+
RecordedRequest actualRequest = this.server.takeRequest();
403+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
404+
405+
}
406+
343407
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.net.URLEncoder;
2020
import java.nio.charset.StandardCharsets;
2121
import java.util.Base64;
22+
import java.util.Collections;
2223

2324
import okhttp3.mockwebserver.MockResponse;
2425
import okhttp3.mockwebserver.MockWebServer;
@@ -27,6 +28,7 @@
2728
import org.junit.jupiter.api.BeforeEach;
2829
import org.junit.jupiter.api.Test;
2930

31+
import org.springframework.core.convert.converter.Converter;
3032
import org.springframework.http.HttpHeaders;
3133
import org.springframework.http.MediaType;
3234
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -212,4 +214,64 @@ private void enqueueJson(String body) {
212214
this.server.enqueue(response);
213215
}
214216

217+
@Test
218+
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
219+
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setHeadersConverter(null))
220+
.withMessage("headersConverter cannot be null");
221+
}
222+
223+
@Test
224+
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
225+
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addHeadersConverter(null))
226+
.withMessage("headersConverter cannot be null");
227+
}
228+
229+
@Test
230+
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
231+
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
232+
this.clientRegistration.build());
233+
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
234+
final HttpHeaders headers = new HttpHeaders();
235+
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
236+
given(addedHeadersConverter.convert(request)).willReturn(headers);
237+
this.client.addHeadersConverter(addedHeadersConverter);
238+
// @formatter:off
239+
enqueueJson("{\n"
240+
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
241+
+ " \"token_type\":\"bearer\",\n"
242+
+ " \"expires_in\":3600,\n"
243+
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
244+
+ "}");
245+
// @formatter:on
246+
this.client.getTokenResponse(request).block();
247+
verify(addedHeadersConverter).convert(request);
248+
RecordedRequest actualRequest = this.server.takeRequest();
249+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
250+
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
251+
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
252+
}
253+
254+
@Test
255+
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
256+
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
257+
this.clientRegistration.build());
258+
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
259+
final HttpHeaders headers = new HttpHeaders();
260+
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
261+
given(headersConverter.convert(request)).willReturn(headers);
262+
this.client.setHeadersConverter(headersConverter);
263+
// @formatter:off
264+
enqueueJson("{\n"
265+
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
266+
+ " \"token_type\":\"bearer\",\n"
267+
+ " \"expires_in\":3600,\n"
268+
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
269+
+ "}");
270+
// @formatter:on
271+
this.client.getTokenResponse(request).block();
272+
verify(headersConverter).convert(request);
273+
RecordedRequest actualRequest = this.server.takeRequest();
274+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
275+
}
276+
215277
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.security.oauth2.client.endpoint;
1818

1919
import java.time.Instant;
20+
import java.util.Collections;
2021

2122
import okhttp3.mockwebserver.MockResponse;
2223
import okhttp3.mockwebserver.MockWebServer;
@@ -25,6 +26,7 @@
2526
import org.junit.jupiter.api.BeforeEach;
2627
import org.junit.jupiter.api.Test;
2728

29+
import org.springframework.core.convert.converter.Converter;
2830
import org.springframework.http.HttpHeaders;
2931
import org.springframework.http.HttpMethod;
3032
import org.springframework.http.MediaType;
@@ -38,6 +40,9 @@
3840
import static org.assertj.core.api.Assertions.assertThat;
3941
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4042
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
43+
import static org.mockito.BDDMockito.given;
44+
import static org.mockito.Mockito.mock;
45+
import static org.mockito.Mockito.verify;
4146

4247
/**
4348
* Tests for {@link WebClientReactivePasswordTokenResponseClient}.
@@ -213,4 +218,66 @@ private MockResponse jsonResponse(String json) {
213218
// @formatter:on
214219
}
215220

221+
@Test
222+
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
223+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
224+
.withMessage("headersConverter cannot be null");
225+
}
226+
227+
@Test
228+
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
229+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
230+
.withMessage("headersConverter cannot be null");
231+
}
232+
233+
@Test
234+
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
235+
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
236+
this.username, this.password);
237+
Converter<OAuth2PasswordGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
238+
final HttpHeaders headers = new HttpHeaders();
239+
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
240+
given(addedHeadersConverter.convert(request)).willReturn(headers);
241+
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
242+
// @formatter:off
243+
String accessTokenSuccessResponse = "{\n"
244+
+ " \"access_token\": \"access-token-1234\",\n"
245+
+ " \"token_type\": \"bearer\",\n"
246+
+ " \"expires_in\": \"3600\",\n"
247+
+ " \"scope\": \"read\"\n"
248+
+ "}\n";
249+
// @formatter:on
250+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
251+
this.tokenResponseClient.getTokenResponse(request).block();
252+
verify(addedHeadersConverter).convert(request);
253+
RecordedRequest actualRequest = this.server.takeRequest();
254+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
255+
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
256+
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
257+
}
258+
259+
@Test
260+
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
261+
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
262+
this.username, this.password);
263+
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
264+
final HttpHeaders headers = new HttpHeaders();
265+
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
266+
given(headersConverter.convert(request)).willReturn(headers);
267+
this.tokenResponseClient.setHeadersConverter(headersConverter);
268+
// @formatter:off
269+
String accessTokenSuccessResponse = "{\n"
270+
+ " \"access_token\": \"access-token-1234\",\n"
271+
+ " \"token_type\": \"bearer\",\n"
272+
+ " \"expires_in\": \"3600\",\n"
273+
+ " \"scope\": \"read\"\n"
274+
+ "}\n";
275+
// @formatter:on
276+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
277+
this.tokenResponseClient.getTokenResponse(request).block();
278+
verify(headersConverter).convert(request);
279+
RecordedRequest actualRequest = this.server.takeRequest();
280+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
281+
}
282+
216283
}

0 commit comments

Comments
 (0)