Skip to content

JdbcOAuth2AuthorizationService supports clob and text datatype for token columns #491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization;

import java.nio.charset.StandardCharsets;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -35,6 +36,7 @@

import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.ConnectionCallback;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
Expand Down Expand Up @@ -141,6 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic

private final JdbcOperations jdbcOperations;
private final LobHandler lobHandler;
private static int tokenColumnType;
private RowMapper<OAuth2Authorization> authorizationRowMapper;
private Function<OAuth2Authorization, List<SqlParameterValue>> authorizationParametersMapper;

Expand Down Expand Up @@ -169,12 +172,15 @@ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
Assert.notNull(lobHandler, "lobHandler cannot be null");
this.jdbcOperations = jdbcOperations;
this.lobHandler = lobHandler;
tokenColumnType = getColumnDataType(jdbcOperations, "access_token_value");
OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper(registeredClientRepository);
authorizationRowMapper.setLobHandler(lobHandler);
this.authorizationRowMapper = authorizationRowMapper;
this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
OAuth2AuthorizationParametersMapper authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
this.authorizationParametersMapper = authorizationParametersMapper;
}


@Override
public void save(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null");
Expand Down Expand Up @@ -232,26 +238,33 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t
List<SqlParameterValue> parameters = new ArrayList<>();
if (tokenType == null) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
parameters.add(mapTokenToSqlParameter(token));
parameters.add(mapTokenToSqlParameter(token));
return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
return findBy(STATE_FILTER, parameters);
} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
return findBy(AUTHORIZATION_CODE_FILTER, parameters);
} else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
return findBy(ACCESS_TOKEN_FILTER, parameters);
} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
return findBy(REFRESH_TOKEN_FILTER, parameters);
}
return null;
}

private SqlParameterValue mapTokenToSqlParameter(String token) {
if (Types.BLOB == tokenColumnType) {
return new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8));
}
return new SqlParameterValue(tokenColumnType, token);
}

private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
try (LobCreator lobCreator = getLobHandler().getLobCreator()) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
Expand Down Expand Up @@ -349,25 +362,22 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
builder.attribute(OAuth2ParameterNames.STATE, state);
}

String tokenValue;
Instant tokenIssuedAt;
Instant tokenExpiresAt;
byte[] authorizationCodeValue = this.lobHandler.getBlobAsBytes(rs, "authorization_code_value");
String authorizationCodeValue = getTokenValue(rs, "authorization_code_value");

if (authorizationCodeValue != null) {
tokenValue = new String(authorizationCodeValue, StandardCharsets.UTF_8);
if (StringUtils.hasText(authorizationCodeValue)) {
tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant();
Map<String, Object> authorizationCodeMetadata = parseMap(rs.getString("authorization_code_metadata"));

OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
tokenValue, tokenIssuedAt, tokenExpiresAt);
authorizationCodeValue, tokenIssuedAt, tokenExpiresAt);
builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
}

byte[] accessTokenValue = this.lobHandler.getBlobAsBytes(rs, "access_token_value");
if (accessTokenValue != null) {
tokenValue = new String(accessTokenValue, StandardCharsets.UTF_8);
String accessTokenValue = getTokenValue(rs, "access_token_value");
if (StringUtils.hasText(accessTokenValue)) {
tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
Map<String, Object> accessTokenMetadata = parseMap(rs.getString("access_token_metadata"));
Expand All @@ -381,25 +391,23 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
if (accessTokenScopes != null) {
scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
}
OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, accessTokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
}

byte[] oidcIdTokenValue = this.lobHandler.getBlobAsBytes(rs, "oidc_id_token_value");
if (oidcIdTokenValue != null) {
tokenValue = new String(oidcIdTokenValue, StandardCharsets.UTF_8);
String oidcIdTokenValue = getTokenValue(rs, "oidc_id_token_value");
if (StringUtils.hasText(oidcIdTokenValue)) {
tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
Map<String, Object> oidcTokenMetadata = parseMap(rs.getString("oidc_id_token_metadata"));

OidcIdToken oidcToken = new OidcIdToken(
tokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
}

byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value");
if (refreshTokenValue != null) {
tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
String refreshTokenValue = getTokenValue(rs, "refresh_token_value");
if (StringUtils.hasText(refreshTokenValue)) {
tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant();
tokenExpiresAt = null;
Timestamp refreshTokenExpiresAt = rs.getTimestamp("refresh_token_expires_at");
Expand All @@ -409,12 +417,29 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
Map<String, Object> refreshTokenMetadata = parseMap(rs.getString("refresh_token_metadata"));

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
tokenValue, tokenIssuedAt, tokenExpiresAt);
refreshTokenValue, tokenIssuedAt, tokenExpiresAt);
builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
}
return builder.build();
}

private String getTokenValue(ResultSet rs, String tokenColumn) throws SQLException {
String tokenValue = null;
if (Types.CLOB == tokenColumnType) {
tokenValue = this.lobHandler.getClobAsString(rs, tokenColumn);
}
if (Types.VARCHAR == tokenColumnType) {
tokenValue = rs.getString(tokenColumn);
}
if (Types.BLOB == tokenColumnType) {
byte[] tokenValueByte = this.lobHandler.getBlobAsBytes(rs, tokenColumn);
if (tokenValueByte != null) {
tokenValue = new String(tokenValueByte, StandardCharsets.UTF_8);
}
}
return tokenValue;
}

public final void setLobHandler(LobHandler lobHandler) {
Assert.notNull(lobHandler, "lobHandler cannot be null");
this.lobHandler = lobHandler;
Expand Down Expand Up @@ -520,12 +545,12 @@ protected final ObjectMapper getObjectMapper() {

private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) {
List<SqlParameterValue> parameters = new ArrayList<>();
byte[] tokenValue = null;
String tokenValue = null;
Timestamp tokenIssuedAt = null;
Timestamp tokenExpiresAt = null;
String metadata = null;
if (token != null) {
tokenValue = token.getToken().getTokenValue().getBytes(StandardCharsets.UTF_8);
tokenValue = token.getToken().getTokenValue();
if (token.getToken().getIssuedAt() != null) {
tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt());
}
Expand All @@ -534,7 +559,13 @@ private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterLi
}
metadata = writeMap(token.getMetadata());
}
parameters.add(new SqlParameterValue(Types.BLOB, tokenValue));
if (Types.BLOB == tokenColumnType && StringUtils.hasText(tokenValue)) {
byte[] tokenValueAsBytes = tokenValue.getBytes(StandardCharsets.UTF_8);
parameters.add(new SqlParameterValue(tokenColumnType, tokenValueAsBytes));
} else {
parameters.add(new SqlParameterValue(tokenColumnType, tokenValue));
}

parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
parameters.add(new SqlParameterValue(Types.VARCHAR, metadata));
Expand All @@ -551,6 +582,23 @@ private String writeMap(Map<String, Object> data) {

}

private static int getColumnDataType(JdbcOperations jdbcOperations, String columnName){
return jdbcOperations.execute((ConnectionCallback<Integer>) con -> {
DatabaseMetaData databaseMetaData = con.getMetaData();
ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
// NOTE: When using HSQL: When a database object is created with one of the CREATE statements if the name is enclosed in double quotes, the exact name is used as the case-normal form.
// But if it is not enclosed in double quotes, the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form
rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(), columnName.toUpperCase());
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
return Types.NULL;
});
}

private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
private final LobCreator lobCreator;

Expand All @@ -572,6 +620,15 @@ protected void doSetValue(PreparedStatement ps, int parameterPosition, Object ar
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
return;
}
if (paramValue.getSqlType() == Types.CLOB) {
if (paramValue.getValue() != null) {
Assert.isInstanceOf(String.class, paramValue.getValue(),
"Value of clob parameter must be String");
}
String valueString = (String) paramValue.getValue();
this.lobCreator.setClobAsString(ps, parameterPosition, valueString);
return;
}
}
super.doSetValue(ps, parameterPosition, argValue);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE oauth2_authorization (
id varchar(100) NOT NULL,
registered_client_id varchar(100) NOT NULL,
principal_name varchar(200) NOT NULL,
authorization_grant_type varchar(100) NOT NULL,
attributes varchar(15000) DEFAULT NULL,
state varchar(500) DEFAULT NULL,
authorization_code_value text DEFAULT NULL,
authorization_code_issued_at timestamp DEFAULT NULL,
authorization_code_expires_at timestamp DEFAULT NULL,
authorization_code_metadata varchar(2000) DEFAULT NULL,
access_token_value text DEFAULT NULL,
access_token_issued_at timestamp DEFAULT NULL,
access_token_expires_at timestamp DEFAULT NULL,
access_token_metadata varchar(2000) DEFAULT NULL,
access_token_type varchar(100) DEFAULT NULL,
access_token_scopes varchar(1000) DEFAULT NULL,
oidc_id_token_value text DEFAULT NULL,
oidc_id_token_issued_at timestamp DEFAULT NULL,
oidc_id_token_expires_at timestamp DEFAULT NULL,
oidc_id_token_metadata varchar(2000) DEFAULT NULL,
refresh_token_value text DEFAULT NULL,
refresh_token_issued_at timestamp DEFAULT NULL,
refresh_token_expires_at timestamp DEFAULT NULL,
refresh_token_metadata varchar(2000) DEFAULT NULL,
PRIMARY KEY (id)
);
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
Expand Down Expand Up @@ -75,6 +76,7 @@
public class JdbcOAuth2AuthorizationServiceTests {
private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql";
private static final String CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema.sql";
private static final String OAUTH2_AUTHORIZATION_SCHEMA_CLOB_COLUMN_TYPE_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema-clob-data-type.sql";
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
private static final String ID = "id";
Expand Down Expand Up @@ -414,6 +416,37 @@ public void tableDefinitionWhenCustomThenAbleToOverride() {
db.shutdown();
}

@Test
public void tableDefinitionWhenClobSqlTypeThenUpdateAuthorization() {
EmbeddedDatabase db = createDb(OAUTH2_AUTHORIZATION_SCHEMA_CLOB_COLUMN_TYPE_SQL_RESOURCE);
OAuth2AuthorizationService authorizationService =
new JdbcOAuth2AuthorizationService(new JdbcTemplate(db), this.registeredClientRepository);
when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
.thenReturn(REGISTERED_CLIENT);
OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.id(ID)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.build();
authorizationService.save(originalAuthorization);

OAuth2Authorization authorization = authorizationService.findById(
originalAuthorization.getId());
assertThat(authorization).isEqualTo(originalAuthorization);

OAuth2Authorization updatedAuthorization = OAuth2Authorization.from(authorization)
.attribute("custom-name-1", "custom-value-1")
.build();
authorizationService.save(updatedAuthorization);

authorization = authorizationService.findById(
updatedAuthorization.getId());
assertThat(authorization).isEqualTo(updatedAuthorization);
assertThat(authorization).isNotEqualTo(originalAuthorization);
db.shutdown();
}

private static EmbeddedDatabase createDb() {
return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
}
Expand Down Expand Up @@ -479,11 +512,14 @@ private static final class CustomJdbcOAuth2AuthorizationService extends JdbcOAut

private CustomJdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
RegisteredClientRepository registeredClientRepository) {
super(jdbcOperations, registeredClientRepository);
super(jdbcOperations, registeredClientRepository, new DefaultLobHandler());
setAuthorizationRowMapper(new CustomOAuth2AuthorizationRowMapper(registeredClientRepository));
setAuthorizationParametersMapper(new CustomOAuth2AuthorizationParametersMapper());

}



@Override
public void save(OAuth2Authorization authorization) {
List<SqlParameterValue> parameters = getAuthorizationParametersMapper().apply(authorization);
Expand Down
Loading