diff --git a/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java index 0a79e6b2446..b556ce222ea 100644 --- a/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java @@ -16,16 +16,17 @@ package org.springframework.security.authentication; -import org.springframework.security.core.Authentication; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import org.springframework.security.core.Authentication; import org.springframework.security.core.userdetails.ReactiveUserDetailsPasswordService; import org.springframework.security.core.userdetails.ReactiveUserDetailsService; +import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.crypto.factory.PasswordEncoderFactories; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; /** * A {@link ReactiveAuthenticationManager} that uses a {@link ReactiveUserDetailsService} to validate the provided @@ -43,6 +44,8 @@ public class UserDetailsRepositoryReactiveAuthenticationManager implements React private Scheduler scheduler = Schedulers.parallel(); + private UserDetailsChecker postAuthenticationChecks = userDetails -> {}; + public UserDetailsRepositoryReactiveAuthenticationManager(ReactiveUserDetailsService userDetailsService) { Assert.notNull(userDetailsService, "userDetailsService cannot be null"); this.userDetailsService = userDetailsService; @@ -65,6 +68,7 @@ public Mono authenticate(Authentication authentication) { } return Mono.just(u); }) + .doOnNext(this.postAuthenticationChecks::check) .map(u -> new UsernamePasswordAuthenticationToken(u, u.getPassword(), u.getAuthorities()) ); } @@ -102,4 +106,16 @@ public void setUserDetailsPasswordService( ReactiveUserDetailsPasswordService userDetailsPasswordService) { this.userDetailsPasswordService = userDetailsPasswordService; } + + /** + * Sets the strategy which will be used to validate the loaded UserDetails + * object after authentication occurs. + * + * @param postAuthenticationChecks The {@link UserDetailsChecker} + * @since 5.2 + */ + public void setPostAuthenticationChecks(UserDetailsChecker postAuthenticationChecks) { + Assert.notNull(this.postAuthenticationChecks, "postAuthenticationChecks cannot be null"); + this.postAuthenticationChecks = postAuthenticationChecks; + } } diff --git a/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java b/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java index 2037bfef6e9..c5c41726480 100644 --- a/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java @@ -16,27 +16,26 @@ package org.springframework.security.authentication; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + import org.springframework.security.core.Authentication; import org.springframework.security.core.userdetails.ReactiveUserDetailsPasswordService; import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.crypto.password.PasswordEncoder; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; - -import static org.assertj.core.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -56,6 +55,9 @@ public class UserDetailsRepositoryReactiveAuthenticationManagerTests { @Mock private Scheduler scheduler; + @Mock + private UserDetailsChecker postAuthenticationChecks; + private UserDetails user = User.withUsername("user") .password("password") .roles("USER") @@ -140,4 +142,33 @@ public void authenticateWhenPasswordServiceAndUpgradeFalseThenNotUpdated() { verifyZeroInteractions(this.userDetailsPasswordService); } + + @Test + public void authenticateWhenPostAuthenticationChecksFail() { + when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); + doThrow(new LockedException("account is locked")).when(this.postAuthenticationChecks).check(any()); + when(this.encoder.matches(any(), any())).thenReturn(true); + this.manager.setPasswordEncoder(this.encoder); + this.manager.setPostAuthenticationChecks(this.postAuthenticationChecks); + + assertThatExceptionOfType(LockedException.class) + .isThrownBy(() -> this.manager.authenticate(new UsernamePasswordAuthenticationToken(this.user, this.user.getPassword())).block()) + .withMessage("account is locked"); + + verify(this.postAuthenticationChecks).check(eq(this.user)); + } + + @Test + public void authenticateWhenPostAuthenticationChecksNotSet() { + when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); + when(this.encoder.matches(any(), any())).thenReturn(true); + this.manager.setPasswordEncoder(this.encoder); + + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( + this.user, this.user.getPassword()); + + this.manager.authenticate(token).block(); + + verifyZeroInteractions(this.postAuthenticationChecks); + } }