Skip to content

Use RDBMS specific queries updating session attributes #1481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.session.jdbc;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -41,9 +42,11 @@
import org.springframework.core.serializer.support.SerializingConverter;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.ConnectionCallback;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.jdbc.support.lob.LobHandler;
import org.springframework.session.DelegatingIndexResolver;
Expand Down Expand Up @@ -154,6 +157,44 @@ public class JdbcOperationsSessionRepository
+ "WHERE SESSION_ID = ?";
// @formatter:on

/**
* MERGE is SQL standadrd. It is supported by SQL Server and Oracle, and eventually
* other databases.
*/
// @formatter:off
private static final String CREATE_SESSION_ATTRIBUTE_QUERY_MERGE = "MERGE INTO %TABLE_NAME%_ATTRIBUTES x "
Copy link

@shark300 shark300 Aug 29, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got a "bad SQL grammar" exception on Oracle 18c if the SQL query has an ending semicolon.

(Yes, I know it's a declined PR, but we're using this as a workaround until the final solution.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this PR was declined, GitHub doesn't seem to be updating it when I push to the branch for it.

I had fixed this a while ago in my branch, so I suggest you grab it from there: https://github.com/candrews/spring-session/tree/attributes-duplicate-key

+ "USING ( "
+ " SELECT PRIMARY_ID as SESSION_PRIMARY_ID, ? as ATTRIBUTE_NAME, ? as ATTRIBUTE_BYTES "
+ " FROM %TABLE_NAME% "
+ " WHERE SESSION_ID = ? "
+ ") y "
+ "ON (x.SESSION_PRIMARY_ID = y.SESSION_PRIMARY_ID and x.ATTRIBUTE_NAME=y.ATTRIBUTE_NAME) "
+ "WHEN MATCHED THEN "
+ " UPDATE SET ATTRIBUTE_BYTES = y.ATTRIBUTE_BYTES "
+ "WHEN NOT MATCHED THEN "
+ " INSERT(SESSION_PRIMARY_ID, ATTRIBUTE_NAME, ATTRIBUTE_BYTES) VALUES (y.SESSION_PRIMARY_ID, y.ATTRIBUTE_NAME, y.ATTRIBUTE_BYTES);";
// @formatter:on

/**
* ON DUPLICATE KEY UPDATE is MySQL/MariaDB specific.
*/
// @formatter:off
private static final String CREATE_SESSION_ATTRIBUTE_QUERY_ON_DUPLICATE_KEY_UPDATE = "INSERT INTO %TABLE_NAME%_ATTRIBUTES(SESSION_PRIMARY_ID, ATTRIBUTE_NAME, ATTRIBUTE_BYTES) "
+ "SELECT PRIMARY_ID, ?, ? "
+ "FROM %TABLE_NAME% "
+ "WHERE SESSION_ID = ? ON DUPLICATE KEY UPDATE ATTRIBUTE_BYTES=VALUES(ATTRIBUTE_BYTES)";
// @formatter:on

/**
* ON CONFLICT is PostgreSQL specific.
*/
// @formatter:off
private static final String CREATE_SESSION_ATTRIBUTE_QUERY_ON_CONFLICT = "INSERT INTO %TABLE_NAME%_ATTRIBUTES(SESSION_PRIMARY_ID, ATTRIBUTE_NAME, ATTRIBUTE_BYTES) "
+ "SELECT PRIMARY_ID, ?, ? "
+ "FROM %TABLE_NAME% "
+ "WHERE SESSION_ID = ? ON CONFLICT(SESSION_PRIMARY_ID, ATTRIBUTE_NAME) DO UPDATE SET ATTRIBUTE_BYTES=EXCLUDED.ATTRIBUTE_BYTES";
// @formatter:on

// @formatter:off
private static final String GET_SESSION_QUERY = "SELECT S.PRIMARY_ID, S.SESSION_ID, S.CREATION_TIME, S.LAST_ACCESS_TIME, S.MAX_INACTIVE_INTERVAL, SA.ATTRIBUTE_NAME, SA.ATTRIBUTE_BYTES "
+ "FROM %TABLE_NAME% S "
Expand Down Expand Up @@ -203,6 +244,8 @@ public class JdbcOperationsSessionRepository

private final IndexResolver<JdbcSession> indexResolver;

private final String commonDatabaseName;

private TransactionOperations transactionOperations = TransactionOperations.withoutTransaction();

/**
Expand Down Expand Up @@ -268,6 +311,7 @@ public JdbcOperationsSessionRepository(JdbcOperations jdbcOperations) {
this.jdbcOperations = jdbcOperations;
this.indexResolver = new DelegatingIndexResolver<>(new PrincipalNameIndexResolver<>());
this.conversionService = createDefaultConversionService();
this.commonDatabaseName = getCommonDatabaseName();
prepareQueries();
}

Expand Down Expand Up @@ -650,8 +694,29 @@ private String getQuery(String base) {
}

private void prepareQueries() {
final String createSessionAttributeQuery;
switch ((this.commonDatabaseName != null) ? this.commonDatabaseName : "Unknown") {
case "MySQL":
createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY_ON_DUPLICATE_KEY_UPDATE;
break;
case "PostgreSQL":
createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY_ON_CONFLICT;
break;
case "DB2":
case "Microsoft SQL Server":
case "Oracle":
createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY_MERGE;
break;
default:
if (logger.isDebugEnabled()) {
logger.warn("Using default create session attribute query because the database's common name, \""
+ commonDatabaseName + "\", is not known");
}
createSessionAttributeQuery = CREATE_SESSION_ATTRIBUTE_QUERY;

}
this.createSessionQuery = getQuery(CREATE_SESSION_QUERY);
this.createSessionAttributeQuery = getQuery(CREATE_SESSION_ATTRIBUTE_QUERY);
this.createSessionAttributeQuery = getQuery(createSessionAttributeQuery);
this.getSessionQuery = getQuery(GET_SESSION_QUERY);
this.updateSessionQuery = getQuery(UPDATE_SESSION_QUERY);
this.updateSessionAttributeQuery = getQuery(UPDATE_SESSION_ATTRIBUTE_QUERY);
Expand All @@ -661,6 +726,23 @@ private void prepareQueries() {
this.deleteSessionsByExpiryTimeQuery = getQuery(DELETE_SESSIONS_BY_EXPIRY_TIME_QUERY);
}

private String getCommonDatabaseName() {
try {
return this.jdbcOperations.execute(new ConnectionCallback<String>() {
@Override
public String doInConnection(final Connection connection) throws SQLException, DataAccessException {
return JdbcUtils.commonDatabaseName(connection.getMetaData().getDatabaseProductName());
}
});
}
catch (Exception ex) {
if (logger.isWarnEnabled()) {
logger.warn("Unable to determine database implementation common name", ex);
}
return null;
}
}

private LobHandler getLobHandler() {
return this.lobHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.mockito.ArgumentCaptor;

import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.ConnectionCallback;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.ResultSetExtractor;
Expand Down Expand Up @@ -70,6 +71,8 @@ class JdbcOperationsSessionRepositoryTests {

private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT";

private static final String MOCK_DATABASE_COMMON_NAME = "Mock Database";

private JdbcOperations jdbcOperations = mock(JdbcOperations.class);

private PlatformTransactionManager transactionManager = mock(PlatformTransactionManager.class);
Expand All @@ -79,6 +82,7 @@ class JdbcOperationsSessionRepositoryTests {
@BeforeEach
void setUp() {
this.repository = new JdbcOperationsSessionRepository(this.jdbcOperations, this.transactionManager);
given(this.jdbcOperations.execute(isA(ConnectionCallback.class))).willReturn(MOCK_DATABASE_COMMON_NAME);
}

@Test
Expand Down Expand Up @@ -239,6 +243,7 @@ void createSessionDefaultMaxInactiveInterval() {

assertThat(session.isNew()).isTrue();
assertThat(session.getMaxInactiveInterval()).isEqualTo(new MapSession().getMaxInactiveInterval());
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -251,6 +256,7 @@ void createSessionCustomMaxInactiveInterval() {

assertThat(session.isNew()).isTrue();
assertThat(session.getMaxInactiveInterval()).isEqualTo(Duration.ofSeconds(interval));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -263,6 +269,7 @@ void saveNewWithoutAttributes() {
assertThat(session.isNew()).isFalse();
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).update(startsWith("INSERT"), isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -279,6 +286,7 @@ void saveNewWithSingleAttribute() {
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).update(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -296,6 +304,7 @@ void saveNewWithMultipleAttributes() {
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("),
isA(BatchPreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -311,6 +320,7 @@ void saveUpdatedAddSingleAttribute() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).update(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -327,6 +337,7 @@ void saveUpdatedAddMultipleAttributes() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("),
isA(BatchPreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -344,6 +355,7 @@ void saveUpdatedModifySingleAttribute() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).update(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -363,6 +375,7 @@ void saveUpdatedModifyMultipleAttributes() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"),
isA(BatchPreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -380,6 +393,7 @@ void saveUpdatedRemoveSingleAttribute() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).update(startsWith("DELETE FROM SPRING_SESSION_ATTRIBUTES WHERE"),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -393,6 +407,7 @@ void saveUpdatedRemoveNonExistingAttribute() {

assertThat(session.isNew()).isFalse();
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -412,6 +427,7 @@ void saveUpdatedRemoveMultipleAttributes() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).batchUpdate(startsWith("DELETE FROM SPRING_SESSION_ATTRIBUTES WHERE"),
isA(BatchPreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -428,6 +444,7 @@ void saveUpdatedAddAndModifyAttribute() {
assertPropagationRequiresNew();
verify(this.jdbcOperations).update(startsWith("INSERT INTO SPRING_SESSION_ATTRIBUTES("),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -442,6 +459,7 @@ void saveUpdatedAddAndRemoveAttribute() {

assertThat(session.isNew()).isFalse();
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -460,6 +478,7 @@ void saveUpdatedModifyAndRemoveAttribute() {
assertPropagationRequiresNew();
verify(this.jdbcOperations).update(startsWith("DELETE FROM SPRING_SESSION_ATTRIBUTES WHERE"),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -478,6 +497,7 @@ void saveUpdatedRemoveAndAddAttribute() {
assertPropagationRequiresNew();
verify(this.jdbcOperations).update(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -493,6 +513,7 @@ void saveUpdatedLastAccessedTime() {
assertPropagationRequiresNew();
verify(this.jdbcOperations, times(1)).update(startsWith("UPDATE SPRING_SESSION SET"),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -504,6 +525,7 @@ void saveUnchanged() {
this.repository.save(session);

assertThat(session.isNew()).isFalse();
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand Down Expand Up @@ -575,6 +597,7 @@ void findByIndexNameAndIndexValueUnknownIndexName() {
.findByIndexNameAndIndexValue("testIndexName", indexValue);

assertThat(sessions).isEmpty();
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand Down Expand Up @@ -649,6 +672,8 @@ void saveNewWithoutTransaction() {

verify(this.jdbcOperations, times(1)).update(startsWith("INSERT INTO SPRING_SESSION"),
isA(PreparedStatementSetter.class));
// 2 times because the JdbcOperationsSessionRepository is invoked twice
verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
verifyZeroInteractions(this.transactionManager);
}
Expand All @@ -664,6 +689,8 @@ void saveUpdatedWithoutTransaction() {

verify(this.jdbcOperations, times(1)).update(startsWith("UPDATE SPRING_SESSION"),
isA(PreparedStatementSetter.class));
// 2 times because the JdbcOperationsSessionRepository is invoked twice
verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
verifyZeroInteractions(this.transactionManager);
}
Expand All @@ -678,6 +705,8 @@ void findByIdWithoutTransaction() {

verify(this.jdbcOperations, times(1)).query(endsWith("WHERE S.SESSION_ID = ?"),
isA(PreparedStatementSetter.class), isA(ResultSetExtractor.class));
// 2 times because the JdbcOperationsSessionRepository is invoked twice
verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
verifyZeroInteractions(this.transactionManager);
}
Expand All @@ -689,6 +718,8 @@ void deleteByIdWithoutTransaction() {

verify(this.jdbcOperations, times(1)).update(eq("DELETE FROM SPRING_SESSION WHERE SESSION_ID = ?"),
anyString());
// 2 times because the JdbcOperationsSessionRepository is invoked twice
verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
verifyZeroInteractions(this.transactionManager);
}
Expand All @@ -704,6 +735,8 @@ void findByIndexNameAndIndexValueWithoutTransaction() {

verify(this.jdbcOperations, times(1)).query(endsWith("WHERE S.PRINCIPAL_NAME = ?"),
isA(PreparedStatementSetter.class), isA(ResultSetExtractor.class));
// 2 times because the JdbcOperationsSessionRepository is invoked twice
verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
verifyZeroInteractions(this.transactionManager);
}
Expand All @@ -714,6 +747,8 @@ void cleanUpExpiredSessionsWithoutTransaction() {
this.repository.cleanUpExpiredSessions();

verify(this.jdbcOperations, times(1)).update(eq("DELETE FROM SPRING_SESSION WHERE EXPIRY_TIME < ?"), anyLong());
// 2 times because the JdbcOperationsSessionRepository is invoked twice
verify(this.jdbcOperations, times(2)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
verifyZeroInteractions(this.transactionManager);
}
Expand All @@ -732,6 +767,7 @@ void saveWithSaveModeOnSetAttribute() {
this.repository.save(session);
verify(this.jdbcOperations).update(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"),
isA(PreparedStatementSetter.class));
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -751,6 +787,7 @@ void saveWithSaveModeOnGetAttribute() {
.forClass(BatchPreparedStatementSetter.class);
verify(this.jdbcOperations).batchUpdate(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), captor.capture());
assertThat(captor.getValue().getBatchSize()).isEqualTo(2);
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand All @@ -770,6 +807,7 @@ void saveWithSaveModeAlways() {
.forClass(BatchPreparedStatementSetter.class);
verify(this.jdbcOperations).batchUpdate(startsWith("UPDATE SPRING_SESSION_ATTRIBUTES SET"), captor.capture());
assertThat(captor.getValue().getBatchSize()).isEqualTo(3);
verify(this.jdbcOperations, times(1)).execute(isA(ConnectionCallback.class));
verifyZeroInteractions(this.jdbcOperations);
}

Expand Down