diff --git a/core/src/main/java/org/springframework/security/aot/hint/OneTimeTokenRuntimeHints.java b/core/src/main/java/org/springframework/security/aot/hint/OneTimeTokenRuntimeHints.java new file mode 100644 index 00000000000..5dd7ddb3ef0 --- /dev/null +++ b/core/src/main/java/org/springframework/security/aot/hint/OneTimeTokenRuntimeHints.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.aot.hint; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.security.authentication.ott.OneTimeToken; +import org.springframework.security.authentication.ott.OneTimeTokenService; + +/** + * + * A JDBC implementation of an {@link OneTimeTokenService} that uses a + * {@link JdbcOperations} for {@link OneTimeToken} persistence. + * + * @author Max Batischev + * @since 6.4 + */ +class OneTimeTokenRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.resources().registerPattern("org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql"); + } + +} diff --git a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java new file mode 100644 index 00000000000..fe8b32e48dd --- /dev/null +++ b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java @@ -0,0 +1,239 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.authentication.ott; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.function.Function; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.scheduling.support.CronTrigger; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * + * A JDBC implementation of an {@link OneTimeTokenService} that uses a + * {@link JdbcOperations} for {@link OneTimeToken} persistence. + * + *

+ * NOTE: This {@code JdbcOneTimeTokenService} depends on the table definition + * described in + * "classpath:org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql" and + * therefore MUST be defined in the database schema. + * + * @author Max Batischev + * @since 6.4 + */ +public final class JdbcOneTimeTokenService implements OneTimeTokenService { + + private final Log logger = LogFactory.getLog(getClass()); + + private final JdbcOperations jdbcOperations; + + private Function> oneTimeTokenParametersMapper = new OneTimeTokenParametersMapper(); + + private RowMapper oneTimeTokenRowMapper = new OneTimeTokenRowMapper(); + + private Clock clock = Clock.systemUTC(); + + private ThreadPoolTaskScheduler taskScheduler; + + private static final String DEFAULT_CLEANUP_CRON = "0 * * * * *"; + + private static final String TABLE_NAME = "one_time_tokens"; + + // @formatter:off + private static final String COLUMN_NAMES = "token_value, " + + "username, " + + "expires_at"; + // @formatter:on + + // @formatter:off + private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?)"; + // @formatter:on + + private static final String FILTER = "token_value = ?"; + + private static final String DELETE_ONE_TIME_TOKEN_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + FILTER; + + // @formatter:off + private static final String SELECT_ONE_TIME_TOKEN_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + FILTER; + // @formatter:on + + // @formatter:off + private static final String DELETE_SESSIONS_BY_EXPIRY_TIME_QUERY = "DELETE FROM " + + TABLE_NAME + + " WHERE expires_at < ?"; + // @formatter:on + + /** + * Constructs a {@code JdbcOneTimeTokenService} using the provide parameters. + * @param jdbcOperations the JDBC operations + * @param cleanupCron cleanup cron expression + */ + public JdbcOneTimeTokenService(JdbcOperations jdbcOperations, String cleanupCron) { + Assert.isTrue(StringUtils.hasText(cleanupCron), "cleanupCron cannot be null orr empty"); + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + this.jdbcOperations = jdbcOperations; + this.taskScheduler = createTaskScheduler(cleanupCron); + } + + /** + * Constructs a {@code JdbcOneTimeTokenService} using the provide parameters. + * @param jdbcOperations the JDBC operations + */ + public JdbcOneTimeTokenService(JdbcOperations jdbcOperations) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + this.jdbcOperations = jdbcOperations; + this.taskScheduler = createTaskScheduler(DEFAULT_CLEANUP_CRON); + } + + @Override + public OneTimeToken generate(GenerateOneTimeTokenRequest request) { + Assert.notNull(request, "generateOneTimeTokenRequest cannot be null"); + String token = UUID.randomUUID().toString(); + Instant fiveMinutesFromNow = this.clock.instant().plusSeconds(300); + OneTimeToken oneTimeToken = new DefaultOneTimeToken(token, request.getUsername(), fiveMinutesFromNow); + insertOneTimeToken(oneTimeToken); + return oneTimeToken; + } + + private void insertOneTimeToken(OneTimeToken oneTimeToken) { + List parameters = this.oneTimeTokenParametersMapper.apply(oneTimeToken); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); + } + + @Override + public OneTimeToken consume(OneTimeTokenAuthenticationToken authenticationToken) { + Assert.notNull(authenticationToken, "authenticationToken cannot be null"); + + List tokens = selectOneTimeToken(authenticationToken); + if (CollectionUtils.isEmpty(tokens)) { + return null; + } + OneTimeToken token = tokens.get(0); + deleteOneTimeToken(token); + if (isExpired(token)) { + return null; + } + return token; + } + + private boolean isExpired(OneTimeToken ott) { + return this.clock.instant().isAfter(ott.getExpiresAt()); + } + + private List selectOneTimeToken(OneTimeTokenAuthenticationToken authenticationToken) { + List parameters = List + .of(new SqlParameterValue(Types.VARCHAR, authenticationToken.getTokenValue())); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper); + } + + private void deleteOneTimeToken(OneTimeToken oneTimeToken) { + List parameters = List + .of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue())); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss); + } + + private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) { + ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); + taskScheduler.setThreadNamePrefix("spring-one-time-tokens-"); + taskScheduler.initialize(); + taskScheduler.schedule(this::cleanUpExpiredTokens, new CronTrigger(cleanupCron)); + return taskScheduler; + } + + public void cleanUpExpiredTokens() { + List parameters = List.of(new SqlParameterValue(Types.TIMESTAMP, Instant.now())); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + int deletedCount = this.jdbcOperations.update(DELETE_SESSIONS_BY_EXPIRY_TIME_QUERY, pss); + this.logger.debug("Cleaned up " + deletedCount + " expired tokens"); + } + + /** + * Sets the {@link Clock} used when generating one-time token and checking token + * expiry. + * @param clock the clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + + /** + * The default {@code Function} that maps {@link OneTimeToken} to a {@code List} of + * {@link SqlParameterValue}. + * + * @author Max Batischev + * @since 6.4 + */ + public static class OneTimeTokenParametersMapper implements Function> { + + @Override + public List apply(OneTimeToken oneTimeToken) { + List parameters = new ArrayList<>(); + parameters.add(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue())); + parameters.add(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getUsername())); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(oneTimeToken.getExpiresAt()))); + return parameters; + } + + } + + /** + * The default {@link RowMapper} that maps the current row in + * {@code java.sql.ResultSet} to {@link OneTimeToken}. + * + * @author Max Batischev + * @since 6.4 + */ + public static class OneTimeTokenRowMapper implements RowMapper { + + @Override + public OneTimeToken mapRow(ResultSet rs, int rowNum) throws SQLException { + String tokenValue = rs.getString("token_value"); + String userName = rs.getString("username"); + Instant expiresAt = rs.getTimestamp("expires_at").toInstant(); + return new DefaultOneTimeToken(tokenValue, userName, expiresAt); + } + + } + +} diff --git a/core/src/main/resources/META-INF/spring/aot.factories b/core/src/main/resources/META-INF/spring/aot.factories index 2a24e540732..8596dc6a3fe 100644 --- a/core/src/main/resources/META-INF/spring/aot.factories +++ b/core/src/main/resources/META-INF/spring/aot.factories @@ -1,4 +1,6 @@ org.springframework.aot.hint.RuntimeHintsRegistrar=\ -org.springframework.security.aot.hint.CoreSecurityRuntimeHints +org.springframework.security.aot.hint.CoreSecurityRuntimeHints,\ +org.springframework.security.aot.hint.OneTimeTokenRuntimeHints + org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor=\ org.springframework.security.aot.hint.SecurityHintsAotProcessor diff --git a/core/src/main/resources/org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql b/core/src/main/resources/org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql new file mode 100644 index 00000000000..2c471ee4042 --- /dev/null +++ b/core/src/main/resources/org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql @@ -0,0 +1,5 @@ +create table one_time_tokens( + token_value varchar(36) not null primary key, + username varchar_ignorecase(50) not null, + expires_at timestamp not null +); diff --git a/core/src/test/java/org/springframework/security/aot/hint/OneTimeTokenRuntimeHintsTests.java b/core/src/test/java/org/springframework/security/aot/hint/OneTimeTokenRuntimeHintsTests.java new file mode 100644 index 00000000000..d132e9f49af --- /dev/null +++ b/core/src/test/java/org/springframework/security/aot/hint/OneTimeTokenRuntimeHintsTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.aot.hint; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.core.io.support.SpringFactoriesLoader; +import org.springframework.util.ClassUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OneTimeTokenRuntimeHints} + * + * @author Max Batischev + */ +class OneTimeTokenRuntimeHintsTests { + + private final RuntimeHints hints = new RuntimeHints(); + + @BeforeEach + void setup() { + SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories") + .load(RuntimeHintsRegistrar.class) + .forEach((registrar) -> registrar.registerHints(this.hints, ClassUtils.getDefaultClassLoader())); + } + + @ParameterizedTest + @MethodSource("getOneTimeTokensSqlFiles") + void oneTimeTokensSqlFilesHasHints(String schemaFile) { + assertThat(RuntimeHintsPredicates.resource().forResource(schemaFile)).accepts(this.hints); + } + + private static Stream getOneTimeTokensSqlFiles() { + return Stream.of("org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql"); + } + +} diff --git a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java new file mode 100644 index 00000000000..9bbbe32fb6e --- /dev/null +++ b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java @@ -0,0 +1,196 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.authentication.ott; + +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.List; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.util.CollectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link JdbcOneTimeTokenService}. + * + * @author Max Batischev + */ +public class JdbcOneTimeTokenServiceTests { + + private static final String USERNAME = "user"; + + private static final String TOKEN_VALUE = "1234"; + + private static final String ONE_TIME_TOKEN_SQL_RESOURCE = "org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql"; + + private EmbeddedDatabase db; + + private JdbcOperations jdbcOperations; + + private JdbcOneTimeTokenService oneTimeTokenService; + + private final JdbcOneTimeTokenService.OneTimeTokenParametersMapper oneTimeTokenParametersMapper = new JdbcOneTimeTokenService.OneTimeTokenParametersMapper(); + + @BeforeEach + void setUp() { + this.db = createDb(); + this.jdbcOperations = new JdbcTemplate(this.db); + this.oneTimeTokenService = new JdbcOneTimeTokenService(this.jdbcOperations); + } + + @AfterEach + public void tearDown() { + this.db.shutdown(); + } + + private static EmbeddedDatabase createDb() { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(ONE_TIME_TOKEN_SQL_RESOURCE) + .build(); + // @formatter:on + } + + @Test + void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcOneTimeTokenService(null)) + .withMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + void generateWhenGenerateOneTimeTokenRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.oneTimeTokenService.generate(null)) + .withMessage("generateOneTimeTokenRequest cannot be null"); + // @formatter:on + } + + @Test + void consumeWhenAuthenticationTokenIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.oneTimeTokenService.consume(null)) + .withMessage("authenticationToken cannot be null"); + // @formatter:on + } + + @Test + void generateThenTokenValueShouldBeValidUuidAndProvidedUsernameIsUsed() { + OneTimeToken oneTimeToken = this.oneTimeTokenService.generate(new GenerateOneTimeTokenRequest(USERNAME)); + + OneTimeToken persistedOneTimeToken = selectOneTimeToken(oneTimeToken.getTokenValue()); + assertThat(persistedOneTimeToken).isNotNull(); + assertThat(persistedOneTimeToken.getUsername()).isNotNull(); + assertThat(persistedOneTimeToken.getTokenValue()).isNotNull(); + assertThat(persistedOneTimeToken.getExpiresAt()).isNotNull(); + } + + @Test + void consumeWhenTokenExistsThenReturnItself() { + OneTimeToken oneTimeToken = this.oneTimeTokenService.generate(new GenerateOneTimeTokenRequest(USERNAME)); + OneTimeTokenAuthenticationToken authenticationToken = new OneTimeTokenAuthenticationToken( + oneTimeToken.getTokenValue()); + + OneTimeToken consumedOneTimeToken = this.oneTimeTokenService.consume(authenticationToken); + + assertThat(consumedOneTimeToken).isNotNull(); + assertThat(consumedOneTimeToken.getUsername()).isNotNull(); + assertThat(consumedOneTimeToken.getTokenValue()).isNotNull(); + assertThat(consumedOneTimeToken.getExpiresAt()).isNotNull(); + OneTimeToken persistedOneTimeToken = selectOneTimeToken(consumedOneTimeToken.getTokenValue()); + assertThat(persistedOneTimeToken).isNull(); + } + + @Test + void consumeWhenTokenDoesNotExistsThenReturnNull() { + OneTimeTokenAuthenticationToken authenticationToken = new OneTimeTokenAuthenticationToken(TOKEN_VALUE); + + OneTimeToken consumedOneTimeToken = this.oneTimeTokenService.consume(authenticationToken); + + assertThat(consumedOneTimeToken).isNull(); + } + + @Test + void consumeWhenTokenIsExpiredThenReturnNull() { + GenerateOneTimeTokenRequest request = new GenerateOneTimeTokenRequest(USERNAME); + OneTimeToken generated = this.oneTimeTokenService.generate(request); + OneTimeTokenAuthenticationToken authenticationToken = new OneTimeTokenAuthenticationToken( + generated.getTokenValue()); + Clock tenMinutesFromNow = Clock.fixed(Instant.now().plus(10, ChronoUnit.MINUTES), ZoneOffset.UTC); + this.oneTimeTokenService.setClock(tenMinutesFromNow); + + OneTimeToken consumed = this.oneTimeTokenService.consume(authenticationToken); + assertThat(consumed).isNull(); + } + + @Test + void cleanupExpiredTokens() { + OneTimeToken token1 = new DefaultOneTimeToken("123", USERNAME, Instant.now().minusSeconds(300)); + OneTimeToken token2 = new DefaultOneTimeToken("456", USERNAME, Instant.now().minusSeconds(300)); + saveToken(token1); + saveToken(token2); + + this.oneTimeTokenService.cleanUpExpiredTokens(); + + OneTimeToken deletedOneTimeToken1 = selectOneTimeToken("123"); + OneTimeToken deletedOneTimeToken2 = selectOneTimeToken("456"); + assertThat(deletedOneTimeToken1).isNull(); + assertThat(deletedOneTimeToken2).isNull(); + } + + private void saveToken(OneTimeToken oneTimeToken) { + List parameters = this.oneTimeTokenParametersMapper.apply(oneTimeToken); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + this.jdbcOperations.update("INSERT INTO one_time_tokens (token_value, username, expires_at) VALUES (?, ?, ?)", + pss); + } + + private OneTimeToken selectOneTimeToken(String tokenValue) { + // @formatter:off + List result = this.jdbcOperations.query( + "select token_value, username, expires_at from one_time_tokens where token_value = ?", + new JdbcOneTimeTokenService.OneTimeTokenRowMapper(), tokenValue); + if (CollectionUtils.isEmpty(result)) { + return null; + } + return result.get(0); + // @formatter:on + } + +}