Skip to content

Commit bf953ad

Browse files
Jithendar12 authored and ritiktrianz committed
Add Support for Case-Insensitive Search in Athena DB2 Connector (awslabs#2784)
1 parent fcd779d commit bf953ad

File tree

5 files changed

+214
-31
lines changed

5 files changed

+214
-31
lines changed

athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandler.java

Lines changed: 28 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -37,14 +37,13 @@
3737
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
3838
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest;
3939
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse;
40-
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
41-
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
4240
import com.amazonaws.athena.connector.lambda.metadata.optimizations.DataSourceOptimizations;
4341
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
4442
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.ComplexExpressionPushdownSubType;
4543
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.FilterPushdownSubType;
4644
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.LimitPushdownSubType;
4745
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.TopNPushdownSubType;
46+
import com.amazonaws.athena.connectors.db2.resolver.Db2JDBCCaseResolver;
4847
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
4948
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
5049
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
@@ -53,6 +52,7 @@
5352
import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
5453
import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
5554
import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
55+
import com.amazonaws.athena.connectors.jdbc.resolver.JDBCCaseResolver;
5656
import com.google.common.annotations.VisibleForTesting;
5757
import com.google.common.collect.ImmutableMap;
5858
import com.google.common.collect.ImmutableSet;
@@ -68,6 +68,7 @@
6868
import java.sql.Connection;
6969
import java.sql.PreparedStatement;
7070
import java.sql.ResultSet;
71+
import java.sql.SQLException;
7172
import java.sql.Statement;
7273
import java.util.ArrayList;
7374
import java.util.Arrays;
@@ -121,18 +122,19 @@ public Db2MetadataHandler(
121122
JdbcConnectionFactory jdbcConnectionFactory,
122123
java.util.Map<String, String> configOptions)
123124
{
124-
super(databaseConnectionConfig, jdbcConnectionFactory, configOptions);
125+
super(databaseConnectionConfig, jdbcConnectionFactory, configOptions, new Db2JDBCCaseResolver(Db2Constants.NAME));
125126
}
126127

127128
@VisibleForTesting
128129
protected Db2MetadataHandler(
129-
DatabaseConnectionConfig databaseConnectionConfig,
130-
SecretsManagerClient secretsManager,
131-
AthenaClient athena,
132-
JdbcConnectionFactory jdbcConnectionFactory,
133-
java.util.Map<String, String> configOptions)
130+
DatabaseConnectionConfig databaseConnectionConfig,
131+
SecretsManagerClient secretsManager,
132+
AthenaClient athena,
133+
JdbcConnectionFactory jdbcConnectionFactory,
134+
java.util.Map<String, String> configOptions,
135+
JDBCCaseResolver caseResolver)
134136
{
135-
super(databaseConnectionConfig, secretsManager, athena, jdbcConnectionFactory, configOptions);
137+
super(databaseConnectionConfig, secretsManager, athena, jdbcConnectionFactory, configOptions, caseResolver);
136138
}
137139

138140
/**
@@ -152,21 +154,20 @@ public ListSchemasResponse doListSchemaNames(final BlockAllocator blockAllocator
152154
}
153155

154156
/**
155-
* Overridden this method to fetch table(s) for selected schema in Athena Data window.
157+
* Overrides the base class method to provide DB2-specific table listing functionality.
156158
*
157-
* @param blockAllocator
158-
* @param listTablesRequest
159-
* @return
159+
* @param connection The JDBC connection to use for querying DB2
160+
* @param schemaName The name of the schema to list tables from
161+
* @return A list of {@link TableName} objects representing the tables and views in the specified schema
162+
* @throws SQLException if there is an error executing the query or processing the results
160163
*/
161164
@Override
162-
public ListTablesResponse doListTables(final BlockAllocator blockAllocator, final ListTablesRequest listTablesRequest) throws Exception
165+
protected List<TableName> listTables(Connection connection, String schemaName) throws SQLException
163166
{
164-
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
165-
LOGGER.info("{}: List table names for Catalog {}, Schema {}", listTablesRequest.getQueryId(), listTablesRequest.getCatalogName(), listTablesRequest.getSchemaName());
166-
List<String> tableNames = getTableList(connection, Db2Constants.QRY_TO_LIST_TABLES_AND_VIEWS, listTablesRequest.getSchemaName());
167-
List<TableName> tables = tableNames.stream().map(tableName -> new TableName(listTablesRequest.getSchemaName(), tableName)).collect(Collectors.toList());
168-
return new ListTablesResponse(listTablesRequest.getCatalogName(), tables, null);
169-
}
167+
List<String> tableNames = getTableList(connection, schemaName);
168+
return tableNames.stream()
169+
.map(tableName -> new TableName(schemaName, tableName))
170+
.collect(Collectors.toList());
170171
}
171172

172173
/**
@@ -395,18 +396,17 @@ private List<String> getSchemaList(final Connection connection, String query) th
395396
* Fetches the table name(s) for the given schema by executing a SQL query over JDBC that pulls
396397
* all the table names from Db2 for that schema.
397398
*
398-
* @param connection
399-
* @param query
400-
* @param schemaName
401-
* @return List<String>
402-
* @throws Exception
399+
* @param connection The JDBC connection to use for querying DB2
400+
* @param schemaName The name of the schema to list tables from
401+
* @return List of table names in the specified schema
402+
* @throws SQLException if any error occurs while executing the query or processing results
403403
*/
404-
private List<String> getTableList(final Connection connection, String query, String schemaName) throws Exception
404+
private List<String> getTableList(final Connection connection, String schemaName) throws SQLException
405405
{
406406
List<String> list = new ArrayList<>();
407-
try (PreparedStatement ps = connection.prepareStatement(query)) {
407+
try (PreparedStatement ps = connection.prepareStatement(Db2Constants.QRY_TO_LIST_TABLES_AND_VIEWS)) {
408408
ps.setString(1, schemaName);
409-
try (ResultSet rs = ps.executeQuery();) {
409+
try (ResultSet rs = ps.executeQuery()) {
410410
while (rs.next()) {
411411
list.add(rs.getString("NAME"));
412412
}

athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxMetadataHandler.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package com.amazonaws.athena.connectors.db2;
2121

22+
import com.amazonaws.athena.connectors.db2.resolver.Db2JDBCCaseResolver;
2223
import com.amazonaws.athena.connectors.jdbc.MultiplexingJdbcMetadataHandler;
2324
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2425
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
@@ -58,6 +59,6 @@ public Db2MuxMetadataHandler(java.util.Map<String, String> configOptions)
5859
protected Db2MuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
5960
Map<String, JdbcMetadataHandler> metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
6061
{
61-
super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions);
62+
super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions, new Db2JDBCCaseResolver(Db2Constants.NAME));
6263
}
6364
}
athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/resolver/Db2JDBCCaseResolver.java

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,63 @@
1+
/*-
2+
* #%L
3+
* athena-db2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.db2.resolver;
21+
22+
import com.amazonaws.athena.connectors.jdbc.resolver.DefaultJDBCCaseResolver;
23+
24+
import java.util.List;
25+
26+
public class Db2JDBCCaseResolver
27+
extends DefaultJDBCCaseResolver
28+
{
29+
private static final String TABLE_NAME_QUERY_TEMPLATE = "SELECT TABNAME FROM SYSCAT.TABLES WHERE TABSCHEMA = ? AND LOWER(TABNAME) = ?";
30+
private static final String SCHEMA_NAME_QUERY_TEMPLATE = "SELECT SCHEMANAME FROM SYSCAT.SCHEMATA WHERE LOWER(SCHEMANAME) = ?";
31+
32+
private static final String SCHEMA_NAME_COLUMN_KEY = "SCHEMANAME";
33+
private static final String TABLE_NAME_COLUMN_KEY = "TABNAME";
34+
35+
public Db2JDBCCaseResolver(String sourceType)
36+
{
37+
super(sourceType, FederationSDKCasingMode.UPPER, FederationSDKCasingMode.LOWER);
38+
}
39+
40+
@Override
41+
protected String getCaseInsensitivelySchemaNameQueryTemplate()
42+
{
43+
return SCHEMA_NAME_QUERY_TEMPLATE;
44+
}
45+
46+
@Override
47+
protected String getCaseInsensitivelySchemaNameColumnKey()
48+
{
49+
return SCHEMA_NAME_COLUMN_KEY;
50+
}
51+
52+
@Override
53+
protected List<String> getCaseInsensitivelyTableNameQueryTemplate()
54+
{
55+
return List.of(TABLE_NAME_QUERY_TEMPLATE);
56+
}
57+
58+
@Override
59+
protected String getCaseInsensitivelyTableNameColumnKey()
60+
{
61+
return TABLE_NAME_COLUMN_KEY;
62+
}
63+
}

athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandlerTest.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
3838
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
3939
import com.amazonaws.athena.connector.lambda.security.FederatedIdentity;
40+
import com.amazonaws.athena.connectors.db2.resolver.Db2JDBCCaseResolver;
4041
import com.amazonaws.athena.connectors.jdbc.TestBase;
4142
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
4243
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
@@ -69,6 +70,7 @@
6970
import java.util.concurrent.atomic.AtomicInteger;
7071
import java.util.stream.Collectors;
7172

73+
import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE;
7274
import static com.amazonaws.athena.connectors.db2.Db2Constants.PARTITION_NUMBER;
7375
import static org.mockito.ArgumentMatchers.any;
7476
import static org.mockito.ArgumentMatchers.nullable;
@@ -96,7 +98,7 @@ public void setup() throws Exception {
9698
this.secretsManager = Mockito.mock(SecretsManagerClient.class);
9799
this.athena = Mockito.mock(AthenaClient.class);
98100
Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}").build());
99-
this.db2MetadataHandler = new Db2MetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of());
101+
this.db2MetadataHandler = new Db2MetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of(), new Db2JDBCCaseResolver(Db2Constants.NAME));
100102
this.federatedIdentity = Mockito.mock(FederatedIdentity.class);
101103
this.blockAllocator = new BlockAllocatorImpl();
102104
}
@@ -337,7 +339,7 @@ public void doListSchemaNames() throws Exception {
337339
@Test
338340
public void doListTables() throws Exception {
339341
String schemaName = "TESTSCHEMA";
340-
ListTablesRequest listTablesRequest = new ListTablesRequest(federatedIdentity, "queryId", "testCatalog", schemaName, null, 0);
342+
ListTablesRequest listTablesRequest = new ListTablesRequest(federatedIdentity, "queryId", "testCatalog", schemaName, null, UNLIMITED_PAGE_SIZE_VALUE);
341343

342344
PreparedStatement stmt = Mockito.mock(PreparedStatement.class);
343345
Mockito.when(this.connection.prepareStatement(Db2Constants.QRY_TO_LIST_TABLES_AND_VIEWS)).thenReturn(stmt);
athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/resolver/Db2JDBCCaseResolverTest.java

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,117 @@
1+
/*-
2+
* #%L
3+
* athena-db2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.db2.resolver;
21+
22+
import com.amazonaws.athena.connector.lambda.domain.TableName;
23+
import com.amazonaws.athena.connector.lambda.resolver.CaseResolver;
24+
import com.amazonaws.athena.connectors.db2.Db2Constants;
25+
import com.amazonaws.athena.connectors.jdbc.TestBase;
26+
import com.amazonaws.athena.connectors.jdbc.resolver.DefaultJDBCCaseResolver;
27+
import org.junit.Before;
28+
import org.junit.Test;
29+
import org.mockito.Mockito;
30+
31+
import java.sql.Connection;
32+
import java.sql.PreparedStatement;
33+
import java.sql.ResultSet;
34+
import java.sql.SQLException;
35+
import java.sql.Types;
36+
import java.util.Map;
37+
import java.util.concurrent.atomic.AtomicInteger;
38+
39+
import static com.amazonaws.athena.connector.lambda.resolver.CaseResolver.CASING_MODE_CONFIGURATION_KEY;
40+
import static org.junit.Assert.assertEquals;
41+
import static org.mockito.ArgumentMatchers.any;
42+
import static org.mockito.Mockito.when;
43+
44+
public class Db2JDBCCaseResolverTest extends TestBase
45+
{
46+
private Connection mockConnection;
47+
private PreparedStatement preparedStatement;
48+
49+
@Before
50+
public void setup() throws SQLException
51+
{
52+
mockConnection = Mockito.mock(Connection.class);
53+
preparedStatement = Mockito.mock(PreparedStatement.class);
54+
when(mockConnection.prepareStatement(any())).thenReturn(preparedStatement);
55+
}
56+
57+
@Test
58+
public void testCaseInsensitiveCaseOnName() throws SQLException
59+
{
60+
String schemaName = "oRaNgE";
61+
String tableName = "ApPlE";
62+
DefaultJDBCCaseResolver resolver = new Db2JDBCCaseResolver(Db2Constants.NAME);
63+
64+
// Mock schema name result
65+
String[] schemaCols = {"SCHEMANAME"};
66+
int[] schemaTypes = {Types.VARCHAR};
67+
Object[][] schemaData = {{schemaName.toLowerCase()}};
68+
69+
ResultSet schemaResultSet = mockResultSet(schemaCols, schemaTypes, schemaData, new AtomicInteger(-1));
70+
when(preparedStatement.executeQuery()).thenReturn(schemaResultSet);
71+
72+
String adjustedSchemaName = resolver.getAdjustedSchemaNameString(mockConnection, schemaName, Map.of(
73+
CASING_MODE_CONFIGURATION_KEY, CaseResolver.FederationSDKCasingMode.CASE_INSENSITIVE_SEARCH.name()));
74+
assertEquals(schemaName.toLowerCase(), adjustedSchemaName);
75+
76+
// Mock table name result
77+
String[] tableCols = {"TABNAME"};
78+
int[] tableTypes = {Types.VARCHAR};
79+
Object[][] tableData = {{tableName.toUpperCase()}};
80+
81+
ResultSet tableResultSet = mockResultSet(tableCols, tableTypes, tableData, new AtomicInteger(-1));
82+
when(preparedStatement.executeQuery()).thenReturn(tableResultSet);
83+
84+
String adjustedTableName = resolver.getAdjustedTableNameString(mockConnection, schemaName, tableName, Map.of(
85+
CASING_MODE_CONFIGURATION_KEY, CaseResolver.FederationSDKCasingMode.CASE_INSENSITIVE_SEARCH.name()));
86+
assertEquals(tableName.toUpperCase(), adjustedTableName);
87+
}
88+
89+
@Test
90+
public void testCaseInsensitiveCaseOnObject() throws SQLException
91+
{
92+
String schemaName = "oRaNgE";
93+
String tableName = "ApPlE";
94+
DefaultJDBCCaseResolver resolver = new Db2JDBCCaseResolver(Db2Constants.NAME);
95+
96+
// Mock schema and table result sets
97+
ResultSet schemaResultSet = mockResultSet(
98+
new String[]{"SCHEMANAME"},
99+
new int[]{Types.VARCHAR},
100+
new Object[][]{{schemaName.toLowerCase()}},
101+
new AtomicInteger(-1));
102+
103+
ResultSet tableResultSet = mockResultSet(
104+
new String[]{"TABNAME"},
105+
new int[]{Types.VARCHAR},
106+
new Object[][]{{tableName.toUpperCase()}},
107+
new AtomicInteger(-1));
108+
109+
when(preparedStatement.executeQuery()).thenReturn(schemaResultSet).thenReturn(tableResultSet);
110+
111+
TableName adjusted = resolver.getAdjustedTableNameObject(
112+
mockConnection,
113+
new TableName(schemaName, tableName),
114+
Map.of(CASING_MODE_CONFIGURATION_KEY, CaseResolver.FederationSDKCasingMode.CASE_INSENSITIVE_SEARCH.name()));
115+
assertEquals(new TableName(schemaName.toLowerCase(), tableName.toUpperCase()), adjusted);
116+
}
117+
}

0 commit comments

Comments
 (0)