|
16 | 16 |
|
17 | 17 | package org.springframework.security.oauth2.client.registration;
|
18 | 18 |
|
| 19 | +import java.lang.reflect.Field; |
| 20 | +import java.lang.reflect.Modifier; |
| 21 | +import java.util.Arrays; |
19 | 22 | import java.util.Collections;
|
20 | 23 | import java.util.LinkedHashMap;
|
| 24 | +import java.util.List; |
21 | 25 | import java.util.Map;
|
22 | 26 | import java.util.Set;
|
23 | 27 | import java.util.stream.Collectors;
|
24 | 28 | import java.util.stream.Stream;
|
25 | 29 |
|
26 | 30 | import org.junit.jupiter.api.Test;
|
| 31 | +import org.junit.jupiter.params.ParameterizedTest; |
| 32 | +import org.junit.jupiter.params.provider.MethodSource; |
27 | 33 |
|
28 | 34 | import org.springframework.security.oauth2.core.AuthenticationMethod;
|
29 | 35 | import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
30 | 36 | import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
31 | 37 |
|
32 | 38 | import static org.assertj.core.api.Assertions.assertThat;
|
33 | 39 | import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
| 40 | +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; |
34 | 41 |
|
35 | 42 | /**
|
36 | 43 | * Tests for {@link ClientRegistration}.
|
@@ -776,4 +783,59 @@ void buildWhenDefaultClientSettingsThenDefaulted() {
|
776 | 783 | assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isFalse();
|
777 | 784 | }
|
778 | 785 |
|
| 786 | + // gh-16382 |
| 787 | + @Test |
| 788 | + void buildWhenNewAuthorizationCodeAndPkceThenBuilds() { |
| 789 | + ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build(); |
| 790 | + ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) |
| 791 | + .clientId(CLIENT_ID) |
| 792 | + .clientSettings(pkceEnabled) |
| 793 | + .authorizationGrantType(new AuthorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())) |
| 794 | + .redirectUri(REDIRECT_URI) |
| 795 | + .authorizationUri(AUTHORIZATION_URI) |
| 796 | + .tokenUri(TOKEN_URI) |
| 797 | + .build(); |
| 798 | + |
| 799 | + // proof key should be false for passivity |
| 800 | + assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue(); |
| 801 | + } |
| 802 | + |
| 803 | + @ParameterizedTest |
| 804 | + @MethodSource("invalidPkceGrantTypes") |
| 805 | + void buildWhenInvalidGrantTypeForPkceThenException(AuthorizationGrantType invalidGrantType) { |
| 806 | + ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build(); |
| 807 | + ClientRegistration.Builder builder = ClientRegistration.withRegistrationId(REGISTRATION_ID) |
| 808 | + .clientId(CLIENT_ID) |
| 809 | + .clientSettings(pkceEnabled) |
| 810 | + .authorizationGrantType(invalidGrantType) |
| 811 | + .redirectUri(REDIRECT_URI) |
| 812 | + .authorizationUri(AUTHORIZATION_URI) |
| 813 | + .tokenUri(TOKEN_URI); |
| 814 | + |
| 815 | + assertThatIllegalStateException().describedAs( |
| 816 | + "clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType={}", |
| 817 | + invalidGrantType) |
| 818 | + .isThrownBy(builder::build); |
| 819 | + } |
| 820 | + |
| 821 | + static List<AuthorizationGrantType> invalidPkceGrantTypes() { |
| 822 | + return Arrays.stream(AuthorizationGrantType.class.getFields()) |
| 823 | + .filter((field) -> Modifier.isFinal(field.getModifiers()) |
| 824 | + && field.getType() == AuthorizationGrantType.class) |
| 825 | + .map((field) -> getStaticValue(field, AuthorizationGrantType.class)) |
| 826 | + .filter((grantType) -> grantType != AuthorizationGrantType.AUTHORIZATION_CODE) |
| 827 | + // ensure works with .equals |
| 828 | + .map((grantType) -> new AuthorizationGrantType(grantType.getValue())) |
| 829 | + .collect(Collectors.toList()); |
| 830 | + } |
| 831 | + |
| 832 | + private static <T> T getStaticValue(Field field, Class<T> clazz) { |
| 833 | + try { |
| 834 | + return (T) field.get(null); |
| 835 | + } |
| 836 | + catch (IllegalAccessException ex) { |
| 837 | + throw new RuntimeException(ex); |
| 838 | + } |
| 839 | + } |
| 840 | + |
779 | 841 | }
|
0 commit comments