Skip to content

Add support JdbcOneTimeTokenService #15842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.aot.hint;

import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.security.authentication.ott.OneTimeToken;
import org.springframework.security.authentication.ott.OneTimeTokenService;

/**
*
* A JDBC implementation of an {@link OneTimeTokenService} that uses a
* {@link JdbcOperations} for {@link OneTimeToken} persistence.
*
* @author Max Batischev
* @since 6.4
*/
class OneTimeTokenRuntimeHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.resources().registerPattern("org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.authentication.ott;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.sql.Types;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.scheduling.support.CronTrigger;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
*
* A JDBC implementation of an {@link OneTimeTokenService} that uses a
* {@link JdbcOperations} for {@link OneTimeToken} persistence.
*
* <p>
* <b>NOTE:</b> This {@code JdbcOneTimeTokenService} depends on the table definition
* described in
* "classpath:org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql" and
* therefore MUST be defined in the database schema.
*
* @author Max Batischev
* @since 6.4
*/
public final class JdbcOneTimeTokenService implements OneTimeTokenService {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a token is never consumed, then it is never deleted even after it expires. We probably need a way to ensure that expired tokens are cleaned up. Likely this would need to be done using something like a TaskScheduler. For an example of what it might look like, you can refer to JdbcIndexedSessionRepository in Spring Session.


private final Log logger = LogFactory.getLog(getClass());

private final JdbcOperations jdbcOperations;

private Function<OneTimeToken, List<SqlParameterValue>> oneTimeTokenParametersMapper = new OneTimeTokenParametersMapper();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless we have a need for it, let's not make this a pluggable strategy or expose it publicly. We can always do it in another iteration.


private RowMapper<OneTimeToken> oneTimeTokenRowMapper = new OneTimeTokenRowMapper();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless we have a need for it, let's not make this a pluggable strategy or expose it publicly. We can always do it in another iteration.


private Clock clock = Clock.systemUTC();

private ThreadPoolTaskScheduler taskScheduler;

private static final String DEFAULT_CLEANUP_CRON = "0 * * * * *";

private static final String TABLE_NAME = "one_time_tokens";

// @formatter:off
private static final String COLUMN_NAMES = "token_value, "
+ "username, "
+ "expires_at";
// @formatter:on

// @formatter:off
private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?)";
// @formatter:on

private static final String FILTER = "token_value = ?";

private static final String DELETE_ONE_TIME_TOKEN_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + FILTER;

// @formatter:off
private static final String SELECT_ONE_TIME_TOKEN_SQL = "SELECT " + COLUMN_NAMES
+ " FROM " + TABLE_NAME
+ " WHERE " + FILTER;
// @formatter:on

// @formatter:off
private static final String DELETE_SESSIONS_BY_EXPIRY_TIME_QUERY = "DELETE FROM "
+ TABLE_NAME
+ " WHERE expires_at < ?";
// @formatter:on

/**
* Constructs a {@code JdbcOneTimeTokenService} using the provide parameters.
* @param jdbcOperations the JDBC operations
* @param cleanupCron cleanup cron expression
*/
public JdbcOneTimeTokenService(JdbcOperations jdbcOperations, String cleanupCron) {
Assert.isTrue(StringUtils.hasText(cleanupCron), "cleanupCron cannot be null orr empty");
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
this.jdbcOperations = jdbcOperations;
this.taskScheduler = createTaskScheduler(cleanupCron);
}

/**
* Constructs a {@code JdbcOneTimeTokenService} using the provide parameters.
* @param jdbcOperations the JDBC operations
*/
public JdbcOneTimeTokenService(JdbcOperations jdbcOperations) {
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
this.jdbcOperations = jdbcOperations;
this.taskScheduler = createTaskScheduler(DEFAULT_CLEANUP_CRON);
}

@Override
public OneTimeToken generate(GenerateOneTimeTokenRequest request) {
Assert.notNull(request, "generateOneTimeTokenRequest cannot be null");
String token = UUID.randomUUID().toString();
Instant fiveMinutesFromNow = this.clock.instant().plusSeconds(300);
OneTimeToken oneTimeToken = new DefaultOneTimeToken(token, request.getUsername(), fiveMinutesFromNow);
insertOneTimeToken(oneTimeToken);
return oneTimeToken;
}

private void insertOneTimeToken(OneTimeToken oneTimeToken) {
List<SqlParameterValue> parameters = this.oneTimeTokenParametersMapper.apply(oneTimeToken);
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
}

@Override
public OneTimeToken consume(OneTimeTokenAuthenticationToken authenticationToken) {
Assert.notNull(authenticationToken, "authenticationToken cannot be null");

List<OneTimeToken> tokens = selectOneTimeToken(authenticationToken);
if (CollectionUtils.isEmpty(tokens)) {
return null;
}
OneTimeToken token = tokens.get(0);
deleteOneTimeToken(token);
if (isExpired(token)) {
return null;
}
return token;
}

private boolean isExpired(OneTimeToken ott) {
return this.clock.instant().isAfter(ott.getExpiresAt());
}

private List<OneTimeToken> selectOneTimeToken(OneTimeTokenAuthenticationToken authenticationToken) {
List<SqlParameterValue> parameters = List
.of(new SqlParameterValue(Types.VARCHAR, authenticationToken.getTokenValue()));
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper);
}

private void deleteOneTimeToken(OneTimeToken oneTimeToken) {
List<SqlParameterValue> parameters = List
.of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue()));
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
}

private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) {
ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
taskScheduler.setThreadNamePrefix("spring-one-time-tokens-");
taskScheduler.initialize();
taskScheduler.schedule(this::cleanUpExpiredTokens, new CronTrigger(cleanupCron));
return taskScheduler;
}

public void cleanUpExpiredTokens() {
List<SqlParameterValue> parameters = List.of(new SqlParameterValue(Types.TIMESTAMP, Instant.now()));
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
int deletedCount = this.jdbcOperations.update(DELETE_SESSIONS_BY_EXPIRY_TIME_QUERY, pss);
this.logger.debug("Cleaned up " + deletedCount + " expired tokens");
}

/**
* Sets the {@link Clock} used when generating one-time token and checking token
* expiry.
* @param clock the clock
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}

/**
* The default {@code Function} that maps {@link OneTimeToken} to a {@code List} of
* {@link SqlParameterValue}.
*
* @author Max Batischev
* @since 6.4
*/
public static class OneTimeTokenParametersMapper implements Function<OneTimeToken, List<SqlParameterValue>> {

@Override
public List<SqlParameterValue> apply(OneTimeToken oneTimeToken) {
List<SqlParameterValue> parameters = new ArrayList<>();
parameters.add(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue()));
parameters.add(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getUsername()));
parameters.add(new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(oneTimeToken.getExpiresAt())));
return parameters;
}

}

/**
* The default {@link RowMapper} that maps the current row in
* {@code java.sql.ResultSet} to {@link OneTimeToken}.
*
* @author Max Batischev
* @since 6.4
*/
public static class OneTimeTokenRowMapper implements RowMapper<OneTimeToken> {

@Override
public OneTimeToken mapRow(ResultSet rs, int rowNum) throws SQLException {
String tokenValue = rs.getString("token_value");
String userName = rs.getString("username");
Instant expiresAt = rs.getTimestamp("expires_at").toInstant();
return new DefaultOneTimeToken(tokenValue, userName, expiresAt);
}

}

}
4 changes: 3 additions & 1 deletion core/src/main/resources/META-INF/spring/aot.factories
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.security.aot.hint.CoreSecurityRuntimeHints
org.springframework.security.aot.hint.CoreSecurityRuntimeHints,\
org.springframework.security.aot.hint.OneTimeTokenRuntimeHints

org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor=\
org.springframework.security.aot.hint.SecurityHintsAotProcessor
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
create table one_time_tokens(
token_value varchar(36) not null primary key,
username varchar_ignorecase(50) not null,
expires_at timestamp not null
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.aot.hint;

import java.util.stream.Stream;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.util.ClassUtils;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link OneTimeTokenRuntimeHints}
*
* @author Max Batischev
*/
class OneTimeTokenRuntimeHintsTests {

private final RuntimeHints hints = new RuntimeHints();

@BeforeEach
void setup() {
SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories")
.load(RuntimeHintsRegistrar.class)
.forEach((registrar) -> registrar.registerHints(this.hints, ClassUtils.getDefaultClassLoader()));
}

@ParameterizedTest
@MethodSource("getOneTimeTokensSqlFiles")
void oneTimeTokensSqlFilesHasHints(String schemaFile) {
assertThat(RuntimeHintsPredicates.resource().forResource(schemaFile)).accepts(this.hints);
}

private static Stream<String> getOneTimeTokensSqlFiles() {
return Stream.of("org/springframework/security/core/ott/jdbc/one-time-tokens-schema.sql");
}

}
Loading