diff --git a/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcOperationsSessionRepository.java b/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcOperationsSessionRepository.java index 4e86428bc..87528c961 100644 --- a/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcOperationsSessionRepository.java +++ b/spring-session-jdbc/src/main/java/org/springframework/session/jdbc/JdbcOperationsSessionRepository.java @@ -16,6 +16,7 @@ package org.springframework.session.jdbc; +import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -41,9 +42,11 @@ import org.springframework.core.serializer.support.SerializingConverter; import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.ConnectionCallback; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.ResultSetExtractor; +import org.springframework.jdbc.support.JdbcUtils; import org.springframework.jdbc.support.lob.DefaultLobHandler; import org.springframework.jdbc.support.lob.LobHandler; import org.springframework.session.DelegatingIndexResolver; @@ -154,6 +157,44 @@ public class JdbcOperationsSessionRepository + "WHERE SESSION_ID = ?"; // @formatter:on + /** + * MERGE is SQL standadrd. It is supported by SQL Server and Oracle, and eventually + * other databases. + */ + // @formatter:off + private static final String CREATE_SESSION_ATTRIBUTE_QUERY_MERGE = "MERGE INTO %TABLE_NAME%_ATTRIBUTES x " + + "USING ( " + + " SELECT PRIMARY_ID as SESSION_PRIMARY_ID, ? as ATTRIBUTE_NAME, ? as ATTRIBUTE_BYTES " + + " FROM %TABLE_NAME% " + + " WHERE SESSION_ID = ? " + + ") y " + + "ON (x.SESSION_PRIMARY_ID = y.SESSION_PRIMARY_ID and x.ATTRIBUTE_NAME=y.ATTRIBUTE_NAME) " + + "WHEN MATCHED THEN " + + " UPDATE SET ATTRIBUTE_BYTES = y.ATTRIBUTE_BYTES " + + "WHEN NOT MATCHED THEN " + + " INSERT(SESSION_PRIMARY_ID, ATTRIBUTE_NAME, ATTRIBUTE_BYTES) VALUES (y.SESSION_PRIMARY_ID, y.ATTRIBUTE_NAME, y.ATTRIBUTE_BYTES);"; + // @formatter:on + + /** + * ON DUPLICATE KEY UPDATE is MySQL/MariaDB specific. + */ + // @formatter:off + private static final String CREATE_SESSION_ATTRIBUTE_QUERY_ON_DUPLICATE_KEY_UPDATE = "INSERT INTO %TABLE_NAME%_ATTRIBUTES(SESSION_PRIMARY_ID, ATTRIBUTE_NAME, ATTRIBUTE_BYTES) " + + "SELECT PRIMARY_ID, ?, ? " + + "FROM %TABLE_NAME% " + + "WHERE SESSION_ID = ? ON DUPLICATE KEY UPDATE ATTRIBUTE_BYTES=VALUES(ATTRIBUTE_BYTES)"; + // @formatter:on + + /** + * ON CONFLICT is PostgreSQL specific. + */ + // @formatter:off + private static final String CREATE_SESSION_ATTRIBUTE_QUERY_ON_CONFLICT = "INSERT INTO %TABLE_NAME%_ATTRIBUTES(SESSION_PRIMARY_ID, ATTRIBUTE_NAME, ATTRIBUTE_BYTES) " + + "SELECT PRIMARY_ID, ?, ? " + + "FROM %TABLE_NAME% " + + "WHERE SESSION_ID = ? ON CONFLICT(SESSION_PRIMARY_ID, ATTRIBUTE_NAME) DO UPDATE SET ATTRIBUTE_BYTES=EXCLUDED.ATTRIBUTE_BYTES"; + // @formatter:on + // @formatter:off private static final String GET_SESSION_QUERY = "SELECT S.PRIMARY_ID, S.SESSION_ID, S.CREATION_TIME, S.LAST_ACCESS_TIME, S.MAX_INACTIVE_INTERVAL, SA.ATTRIBUTE_NAME, SA.ATTRIBUTE_BYTES " + "FROM %TABLE_NAME% S " @@ -203,6 +244,8 @@ public class JdbcOperationsSessionRepository private final IndexResolver indexResolver; + private final String commonDatabaseName; + private TransactionOperations transactionOperations = TransactionOperations.withoutTransaction(); /** @@ -268,6 +311,7 @@ public JdbcOperationsSessionRepository(JdbcOperations jdbcOperations) { this.jdbcOperations = jdbcOperations; this.indexResolver = new DelegatingIndexResolver<>(new PrincipalNameIndexResolver<>()); this.conversionService = createDefaultConversionService(); + this.commonDatabaseName = getCommonDatabaseName(); prepareQueries(); } @@ -650,8 +694,29 @@ private String getQuery(String base) { } private void prepareQueries() { + final String createSessionAttributeQuery; + switch ((this.commonDatabaseName != null) ? this.commonDatabaseName : "Unknown") { + case "MySQL": + createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY_ON_DUPLICATE_KEY_UPDATE; + break; + case "PostgreSQL": + createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY_ON_CONFLICT; + break; + case "DB2": + case "Microsoft SQL Server": + case "Oracle": + createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY_MERGE; + break; + default: + if (logger.isDebugEnabled()) { + logger.warn("Using default create session attribute query because the database's common name, \"" + + commonDatabaseName + "\", is not known"); + } + createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY; + + } this.createSessionQuery = getQuery(CREATE_SESSION_QUERY); - this.createSessionAttributeQuery = getQuery(CREATE_SESSION_ATTRIBUTE_QUERY); + this.createSessionAttributeQuery = getQuery(createSessionAttributeQuery); this.getSessionQuery = getQuery(GET_SESSION_QUERY); this.updateSessionQuery = getQuery(UPDATE_SESSION_QUERY); this.updateSessionAttributeQuery = getQuery(UPDATE_SESSION_ATTRIBUTE_QUERY); @@ -661,6 +726,23 @@ private void prepareQueries() { this.deleteSessionsByExpiryTimeQuery = getQuery(DELETE_SESSIONS_BY_EXPIRY_TIME_QUERY); } + private String getCommonDatabaseName() { + try { + return this.jdbcOperations.execute(new ConnectionCallback() { + @Override + public String doInConnection(final Connection connection) throws SQLException, DataAccessException { + return JdbcUtils.commonDatabaseName(connection.getMetaData().getDatabaseProductName()); + } + }); + } + catch (Exception ex) { + if (logger.isWarnEnabled()) { + logger.warn("Unable to determine database implementation common name", ex); + } + return null; + } + } + private LobHandler getLobHandler() { return this.lobHandler; } diff --git a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcOperationsSessionRepositoryTests.java b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcOperationsSessionRepositoryTests.java index c616cef42..c1983db33 100644 --- a/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcOperationsSessionRepositoryTests.java +++ b/spring-session-jdbc/src/test/java/org/springframework/session/jdbc/JdbcOperationsSessionRepositoryTests.java @@ -30,6 +30,7 @@ import org.mockito.ArgumentCaptor; import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.ConnectionCallback; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.jdbc.core.ResultSetExtractor; @@ -70,6 +71,8 @@ class JdbcOperationsSessionRepositoryTests { private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; + private static final String MOCK_DATABASE_COMMON_NAME = "Mock Database"; + private JdbcOperations jdbcOperations = mock(JdbcOperations.class); private PlatformTransactionManager transactionManager = mock(PlatformTransactionManager.class); @@ -79,6 +82,7 @@ class JdbcOperationsSessionRepositoryTests { @BeforeEach void setUp() { this.repository = new JdbcOperationsSessionRepository(this.jdbcOperations, this.transactionManager); + given(this.jdbcOperations.execute(isA(ConnectionCallback.class))).willReturn(MOCK_DATABASE_COMMON_NAME); } @Test @@ -239,6 +243,7 @@ void createSessionDefaultMaxInactiveInterval() { assertThat(session.isNew()).isTrue(); assertThat(session.getMaxInactiveInterval()).isEqualTo(new MapSession().getMaxInactiveInterval()); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -251,6 +256,7 @@ void createSessionCustomMaxInactiveInterval() { assertThat(session.isNew()).isTrue(); assertThat(session.getMaxInactiveInterval()).isEqualTo(Duration.ofSeconds(interval)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -263,6 +269,7 @@ void saveNewWithoutAttributes() { assertThat(session.isNew()).isFalse(); assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).update(startsWith("INSERT"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -279,6 +286,7 @@ void saveNewWithSingleAttribute() { isA(PreparedStatementSetter.class)); verify(this.jdbcOperations, times(1)).update(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -296,6 +304,7 @@ void saveNewWithMultipleAttributes() { isA(PreparedStatementSetter.class)); verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("), isA(BatchPreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -311,6 +320,7 @@ void saveUpdatedAddSingleAttribute() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).update(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -327,6 +337,7 @@ void saveUpdatedAddMultipleAttributes() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("), isA(BatchPreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -344,6 +355,7 @@ void saveUpdatedModifySingleAttribute() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).update(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -363,6 +375,7 @@ void saveUpdatedModifyMultipleAttributes() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), isA(BatchPreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -380,6 +393,7 @@ void saveUpdatedRemoveSingleAttribute() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).update(startsWith("DELETE FROM SPRING_SESSION_ATTRIBUTES WHERE"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -393,6 +407,7 @@ void saveUpdatedRemoveNonExistingAttribute() { assertThat(session.isNew()).isFalse(); assertPropagationRequiresNew(); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -412,6 +427,7 @@ void saveUpdatedRemoveMultipleAttributes() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("DELETE FROM SPRING_SESSION_ATTRIBUTES WHERE"), isA(BatchPreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -428,6 +444,7 @@ void saveUpdatedAddAndModifyAttribute() { assertPropagationRequiresNew(); verify(this.jdbcOperations).update(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -442,6 +459,7 @@ void saveUpdatedAddAndRemoveAttribute() { assertThat(session.isNew()).isFalse(); assertPropagationRequiresNew(); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -460,6 +478,7 @@ void saveUpdatedModifyAndRemoveAttribute() { assertPropagationRequiresNew(); verify(this.jdbcOperations).update(startsWith("DELETE FROM SPRING_SESSION_ATTRIBUTES WHERE"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -478,6 +497,7 @@ void saveUpdatedRemoveAndAddAttribute() { assertPropagationRequiresNew(); verify(this.jdbcOperations).update(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -493,6 +513,7 @@ void saveUpdatedLastAccessedTime() { assertPropagationRequiresNew(); verify(this.jdbcOperations, times(1)).update(startsWith("UPDATE SPRING_SESSION SET"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -504,6 +525,7 @@ void saveUnchanged() { this.repository.save(session); assertThat(session.isNew()).isFalse(); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -575,6 +597,7 @@ void findByIndexNameAndIndexValueUnknownIndexName() { .findByIndexNameAndIndexValue("testIndexName", indexValue); assertThat(sessions).isEmpty(); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -649,6 +672,8 @@ void saveNewWithoutTransaction() { verify(this.jdbcOperations, times(1)).update(startsWith("INSERT INTO SPRING_SESSION"), isA(PreparedStatementSetter.class)); + // 2 times because the JdbcOperationsSessionRepository is invoked twice + verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); verifyZeroInteractions(this.transactionManager); } @@ -664,6 +689,8 @@ void saveUpdatedWithoutTransaction() { verify(this.jdbcOperations, times(1)).update(startsWith("UPDATE SPRING_SESSION"), isA(PreparedStatementSetter.class)); + // 2 times because the JdbcOperationsSessionRepository is invoked twice + verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); verifyZeroInteractions(this.transactionManager); } @@ -678,6 +705,8 @@ void findByIdWithoutTransaction() { verify(this.jdbcOperations, times(1)).query(endsWith("WHERE S.SESSION_ID = ?"), isA(PreparedStatementSetter.class), isA(ResultSetExtractor.class)); + // 2 times because the JdbcOperationsSessionRepository is invoked twice + verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); verifyZeroInteractions(this.transactionManager); } @@ -689,6 +718,8 @@ void deleteByIdWithoutTransaction() { verify(this.jdbcOperations, times(1)).update(eq("DELETE FROM SPRING_SESSION WHERE SESSION_ID = ?"), anyString()); + // 2 times because the JdbcOperationsSessionRepository is invoked twice + verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); verifyZeroInteractions(this.transactionManager); } @@ -704,6 +735,8 @@ void findByIndexNameAndIndexValueWithoutTransaction() { verify(this.jdbcOperations, times(1)).query(endsWith("WHERE S.PRINCIPAL_NAME = ?"), isA(PreparedStatementSetter.class), isA(ResultSetExtractor.class)); + // 2 times because the JdbcOperationsSessionRepository is invoked twice + verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); verifyZeroInteractions(this.transactionManager); } @@ -714,6 +747,8 @@ void cleanUpExpiredSessionsWithoutTransaction() { this.repository.cleanUpExpiredSessions(); verify(this.jdbcOperations, times(1)).update(eq("DELETE FROM SPRING_SESSION WHERE EXPIRY_TIME < ?"), anyLong()); + // 2 times because the JdbcOperationsSessionRepository is invoked twice + verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); verifyZeroInteractions(this.transactionManager); } @@ -732,6 +767,7 @@ void saveWithSaveModeOnSetAttribute() { this.repository.save(session); verify(this.jdbcOperations).update(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), isA(PreparedStatementSetter.class)); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -751,6 +787,7 @@ void saveWithSaveModeOnGetAttribute() { .forClass(BatchPreparedStatementSetter.class); verify(this.jdbcOperations).batchUpdate(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), captor.capture()); assertThat(captor.getValue().getBatchSize()).isEqualTo(2); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); } @@ -770,6 +807,7 @@ void saveWithSaveModeAlways() { .forClass(BatchPreparedStatementSetter.class); verify(this.jdbcOperations).batchUpdate(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), captor.capture()); assertThat(captor.getValue().getBatchSize()).isEqualTo(3); + verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class)); verifyZeroInteractions(this.jdbcOperations); }