From 54f3dc0a7d7789de623485b1c43d002cf93a7f8b Mon Sep 17 00:00:00 2001
From: Jithendar12
Date: Mon, 26 May 2025 11:22:35 +0530
Subject: [PATCH] Add unit tests to the Synapse connector.

---
 .../SynapseEnvironmentPropertiesTest.java     |  60 ++
 .../synapse/SynapseMetadataHandlerTest.java   | 305 +++++++--
 .../synapse/SynapseRecordHandlerTest.java     | 617 ++++++++++++++++--
 .../resolver/SynapseJDBCCaseResolverTest.java | 117 ++++
 4 files changed, 984 insertions(+), 115 deletions(-)
 create mode 100644 athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseEnvironmentPropertiesTest.java
 create mode 100644 athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/resolver/SynapseJDBCCaseResolverTest.java

diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseEnvironmentPropertiesTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseEnvironmentPropertiesTest.java
new file mode 100644
index 0000000000..ffcae0f3b5
--- /dev/null
+++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseEnvironmentPropertiesTest.java
@@ -0,0 +1,60 @@
+/*-
+ * #%L
+ * athena-synapse
+ * %%
+ * Copyright (C) 2019 - 2025 Amazon Web Services
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package com.amazonaws.athena.connectors.synapse;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DATABASE;
+import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT;
+import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.HOST;
+import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.PORT;
+import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.SECRET_NAME;
+import static org.junit.Assert.assertEquals;
+
+public class SynapseEnvironmentPropertiesTest
+{
+    private Map<String, String> connectionProperties;
+    private SynapseEnvironmentProperties synapseEnvironmentProperties;
+
+    @Before
+    public void setUp()
+    {
+        connectionProperties = new HashMap<>();
+        connectionProperties.put(HOST, "test.sql.azuresynapse.net");
+        connectionProperties.put(PORT, "1433");
+        connectionProperties.put(DATABASE, "testdb");
+        connectionProperties.put(SECRET_NAME, "synapse-secret");
+
+        synapseEnvironmentProperties = new SynapseEnvironmentProperties();
+    }
+
+    @Test
+    public void connectionPropertiesToEnvironment_WithValidProperties_ReturnsCorrectConnectionString()
+    {
+        Map<String, String> synapseConnectionProperties = synapseEnvironmentProperties.connectionPropertiesToEnvironment(connectionProperties);
+
+        String expectedConnectionString = "synapse://jdbc:sqlserver://test.sql.azuresynapse.net:1433;databaseName=testdb;${synapse-secret}";
+        assertEquals(expectedConnectionString, synapseConnectionProperties.get(DEFAULT));
+    }
+}
diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java
index 9239c2ae1d..368fd78208 100644
--- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java
+++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java
@@ -19,6 +19,7 @@
  */
 package com.amazonaws.athena.connectors.synapse;
 
+import com.amazonaws.athena.connector.credentials.CredentialsProvider;
 import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
 import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
 import com.amazonaws.athena.connector.lambda.data.BlockUtils;
@@ -27,6 +28,8 @@
 import com.amazonaws.athena.connector.lambda.domain.Split;
 import com.amazonaws.athena.connector.lambda.domain.TableName;
 import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
+import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
+import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
 import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
 import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
 import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -35,13 +38,14 @@
 import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse;
 import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
 import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
+import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
 import com.amazonaws.athena.connector.lambda.security.FederatedIdentity;
 import com.amazonaws.athena.connectors.jdbc.TestBase;
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; -import com.amazonaws.athena.connector.credentials.CredentialsProvider; import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.synapse.resolver.SynapseJDBCCaseResolver; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; @@ -57,15 +61,19 @@ import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import java.sql.Types; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -75,6 +83,9 @@ import static com.amazonaws.athena.connectors.synapse.SynapseMetadataHandler.PARTITION_COLUMN; import static com.amazonaws.athena.connectors.synapse.SynapseMetadataHandler.PARTITION_NUMBER; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.nullable; @@ -87,8 +98,20 @@ public class SynapseMetadataHandlerTest extends TestBase { private static final Logger logger = LoggerFactory.getLogger(SynapseMetadataHandlerTest.class); - private DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", SynapseConstants.NAME, - "synapse://jdbc:sqlserver://hostname;databaseName=fakedatabase"); + private static final String TEST_CATALOG = "testCatalog"; + private static final String TEST_SCHEMA = "TESTSCHEMA"; + private static final String TEST_TABLE = "TESTTABLE"; + private static final String TEST_QUERY_ID = "testQueryId"; + private static final String TEST_DB_NAME = "fakedatabase"; + private static final String TEST_HOSTNAME = "hostname"; + private static final String TEST_SECRET = "testSecret"; + private static final String TEST_USER = "testUser"; + private static final String TEST_PASS = "testPassword"; + private static final String TEST_JDBC_URL = String.format("synapse://jdbc:sqlserver://%s;databaseName=%s", TEST_HOSTNAME, TEST_DB_NAME); + + private final DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig(TEST_CATALOG, SynapseConstants.NAME, + TEST_JDBC_URL); + private static final Schema PARTITION_SCHEMA = SchemaBuilder.newBuilder().addField(PARTITION_NUMBER, org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build(); private SynapseMetadataHandler synapseMetadataHandler; private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; @@ -103,20 +126,25 @@ public void setup() System.setProperty("aws.region", "us-east-1"); this.jdbcConnectionFactory = mock(JdbcConnectionFactory.class, RETURNS_DEEP_STUBS); this.connection = mock(Connection.class, RETURNS_DEEP_STUBS); - logger.info(" this.connection.." 
+ this.connection); + logger.info(" this.connection..{}", this.connection); when(this.jdbcConnectionFactory.getConnection(nullable(CredentialsProvider.class))).thenReturn(this.connection); this.secretsManager = mock(SecretsManagerClient.class); this.athena = mock(AthenaClient.class); - when(this.secretsManager.getSecretValue(eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}").build()); - this.synapseMetadataHandler = new SynapseMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of(), new SynapseJDBCCaseResolver(SynapseConstants.NAME)); + when(this.secretsManager.getSecretValue(eq(GetSecretValueRequest.builder().secretId(TEST_SECRET).build()))) + .thenReturn(GetSecretValueResponse.builder() + .secretString(String.format("{\"user\": \"%s\", \"password\": \"%s\"}", TEST_USER, TEST_PASS)) + .build()); + this.synapseMetadataHandler = new SynapseMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, + this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of(), new SynapseJDBCCaseResolver(SynapseConstants.NAME)); this.federatedIdentity = mock(FederatedIdentity.class); } @Test - public void getPartitionSchema() { + public void getPartitionSchema() + { assertEquals(SchemaBuilder.newBuilder() .addField(PARTITION_NUMBER, org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build(), - this.synapseMetadataHandler.getPartitionSchema("testCatalogName")); + this.synapseMetadataHandler.getPartitionSchema(TEST_CATALOG)); } @Test @@ -125,11 +153,11 @@ public void doGetTableLayout() { BlockAllocator blockAllocator = new BlockAllocatorImpl(); Constraints constraints = mock(Constraints.class); - TableName tableName = new TableName("testSchema", "testTable"); + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); Schema partitionSchema = this.synapseMetadataHandler.getPartitionSchema("testCatalogName"); Set partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()); - GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols); + GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, constraints, partitionSchema, partitionCols); String[] columns = {"ROW_COUNT", PARTITION_NUMBER, PARTITION_COLUMN, "PARTITION_BOUNDARY_VALUE"}; int[] types = {Types.INTEGER, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR}; @@ -161,10 +189,10 @@ public void doGetTableLayoutWithNoPartitions() { BlockAllocator blockAllocator = new BlockAllocatorImpl(); Constraints constraints = mock(Constraints.class); - TableName tableName = new TableName("testSchema", "testTable"); + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); Schema partitionSchema = this.synapseMetadataHandler.getPartitionSchema("testCatalogName"); Set partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()); - GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols); + GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, 
TEST_QUERY_ID, "testCatalogName", tableName, constraints, partitionSchema, partitionCols); Object[][] values = {{}}; ResultSet resultSet = mockResultSet(new String[]{"ROW_COUNT"}, new int[]{Types.INTEGER}, values, new AtomicInteger(-1)); @@ -196,10 +224,10 @@ public void doGetTableLayoutWithSQLException() throws Exception { Constraints constraints = mock(Constraints.class); - TableName tableName = new TableName("testSchema", "testTable"); + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); Schema partitionSchema = this.synapseMetadataHandler.getPartitionSchema("testCatalogName"); Set partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()); - GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols); + GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, constraints, partitionSchema, partitionCols); Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS); JdbcConnectionFactory jdbcConnectionFactory = mock(JdbcConnectionFactory.class); @@ -216,7 +244,7 @@ public void doGetSplits() { BlockAllocator blockAllocator = new BlockAllocatorImpl(); Constraints constraints = mock(Constraints.class); - TableName tableName = new TableName("testSchema", "testTable"); + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); String[] columns = {"ROW_COUNT", PARTITION_NUMBER, PARTITION_COLUMN, "PARTITION_BOUNDARY_VALUE"}; int[] types = {Types.INTEGER, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR}; @@ -230,12 +258,12 @@ public void doGetSplits() Schema partitionSchema = this.synapseMetadataHandler.getPartitionSchema("testCatalogName"); Set partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()); - GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols); + GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, constraints, partitionSchema, partitionCols); GetTableLayoutResponse getTableLayoutResponse = this.synapseMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest); BlockAllocator splitBlockAllocator = new BlockAllocatorImpl(); - GetSplitsRequest getSplitsRequest = new GetSplitsRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, getTableLayoutResponse.getPartitions(), new ArrayList<>(partitionCols), constraints, null); + GetSplitsRequest getSplitsRequest = new GetSplitsRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, getTableLayoutResponse.getPartitions(), new ArrayList<>(partitionCols), constraints, null); GetSplitsResponse getSplitsResponse = this.synapseMetadataHandler.doGetSplits(splitBlockAllocator, getSplitsRequest); // TODO: Not sure why this is a set of maps, but I'm not going to change it @@ -274,7 +302,7 @@ public void doGetSplitsWithNoPartition() { BlockAllocator blockAllocator = new BlockAllocatorImpl(); Constraints constraints = mock(Constraints.class); - TableName tableName = new TableName("testSchema", "testTable"); + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); Object[][] values = {{}}; ResultSet resultSet = mockResultSet(new String[]{"ROW_COUNT"}, new 
int[]{Types.INTEGER}, values, new AtomicInteger(-1));
@@ -285,12 +313,12 @@ public void doGetSplitsWithNoPartition()
         Schema partitionSchema = this.synapseMetadataHandler.getPartitionSchema("testCatalogName");
         Set<String> partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet());
-        GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols);
+        GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, constraints, partitionSchema, partitionCols);
         GetTableLayoutResponse getTableLayoutResponse = this.synapseMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);
         BlockAllocator splitBlockAllocator = new BlockAllocatorImpl();
-        GetSplitsRequest getSplitsRequest = new GetSplitsRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, getTableLayoutResponse.getPartitions(), new ArrayList<>(partitionCols), constraints, null);
+        GetSplitsRequest getSplitsRequest = new GetSplitsRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, getTableLayoutResponse.getPartitions(), new ArrayList<>(partitionCols), constraints, null);
         GetSplitsResponse getSplitsResponse = this.synapseMetadataHandler.doGetSplits(splitBlockAllocator, getSplitsRequest);
         Set<Map<String, String>> expectedSplits = new HashSet<>();
@@ -306,10 +334,10 @@ public void doGetSplitsContinuation()
     {
         BlockAllocator blockAllocator = new BlockAllocatorImpl();
         Constraints constraints = mock(Constraints.class);
-        TableName tableName = new TableName("testSchema", "testTable");
+        TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE);
         Schema partitionSchema = this.synapseMetadataHandler.getPartitionSchema("testCatalogName");
         Set<String> partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet());
-        GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols);
+        GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, constraints, partitionSchema, partitionCols);
         String[] columns = {"ROW_COUNT", PARTITION_NUMBER, PARTITION_COLUMN, "PARTITION_BOUNDARY_VALUE"};
         int[] types = {Types.INTEGER, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR};
@@ -324,7 +352,7 @@ public void doGetSplitsContinuation()
         GetTableLayoutResponse getTableLayoutResponse = this.synapseMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);
         BlockAllocator splitBlockAllocator = new BlockAllocatorImpl();
-        GetSplitsRequest getSplitsRequest = new GetSplitsRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, getTableLayoutResponse.getPartitions(), new ArrayList<>(partitionCols), constraints, "2");
+        GetSplitsRequest getSplitsRequest = new GetSplitsRequest(this.federatedIdentity, TEST_QUERY_ID, "testCatalogName", tableName, getTableLayoutResponse.getPartitions(), new ArrayList<>(partitionCols), constraints, "2");
         GetSplitsResponse getSplitsResponse = this.synapseMetadataHandler.doGetSplits(splitBlockAllocator, getSplitsRequest);
         Set<Map<String, String>> expectedSplits = com.google.common.collect.ImmutableSet.of(
@@ -346,8 +374,6 @@ public void doGetSplitsContinuation()
     public void doGetTable()
             throws Exception
     {
-        Schema PARTITION_SCHEMA = 
SchemaBuilder.newBuilder().addField(PARTITION_NUMBER, org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build(); - BlockAllocator blockAllocator = new BlockAllocatorImpl(); String[] schema = {"DATA_TYPE", "COLUMN_NAME", "PRECISION", "SCALE"}; int[] types = {Types.INTEGER, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR}; @@ -375,15 +401,15 @@ public void doGetTable() when(connection.getMetaData().getURL()).thenReturn("jdbc:sqlserver://hostname;databaseName=fakedatabase"); - TableName inputTableName = new TableName("TESTSCHEMA", "TESTTABLE"); - when(connection.getCatalog()).thenReturn("testCatalog"); - when(connection.getMetaData().getColumns("testCatalog", inputTableName.getSchemaName(), inputTableName.getTableName(), null)).thenReturn(resultSet2); + TableName inputTableName = new TableName(TEST_SCHEMA, TEST_TABLE); + when(connection.getCatalog()).thenReturn(TEST_CATALOG); + when(connection.getMetaData().getColumns(TEST_CATALOG, inputTableName.getSchemaName(), inputTableName.getTableName(), null)).thenReturn(resultSet2); GetTableResponse getTableResponse = this.synapseMetadataHandler.doGetTable( - blockAllocator, new GetTableRequest(this.federatedIdentity, "testQueryId", "testCatalog", inputTableName, Collections.emptyMap())); + blockAllocator, new GetTableRequest(this.federatedIdentity, TEST_QUERY_ID, TEST_CATALOG, inputTableName, Collections.emptyMap())); assertEquals(expected, getTableResponse.getSchema()); assertEquals(inputTableName, getTableResponse.getTableName()); - assertEquals("testCatalog", getTableResponse.getCatalogName()); + assertEquals(TEST_CATALOG, getTableResponse.getCatalogName()); } @Test @@ -409,22 +435,21 @@ public void doDataTypeConversion() when(connection.getMetaData().getURL()).thenReturn("jdbc:sqlserver://hostname-ondemand;databaseName=fakedatabase"); - TableName inputTableName = new TableName("TESTSCHEMA", "TESTTABLE"); - when(connection.getCatalog()).thenReturn("testCatalog"); - when(connection.getMetaData().getColumns("testCatalog", inputTableName.getSchemaName(), inputTableName.getTableName(), null)).thenReturn(resultSet2); + TableName inputTableName = new TableName(TEST_SCHEMA, TEST_TABLE); + when(connection.getCatalog()).thenReturn(TEST_CATALOG); + when(connection.getMetaData().getColumns(TEST_CATALOG, inputTableName.getSchemaName(), inputTableName.getTableName(), null)).thenReturn(resultSet2); GetTableResponse getTableResponse = this.synapseMetadataHandler.doGetTable( - blockAllocator, new GetTableRequest(this.federatedIdentity, "testQueryId", "testCatalog", inputTableName, Collections.emptyMap())); + blockAllocator, new GetTableRequest(this.federatedIdentity, TEST_QUERY_ID, TEST_CATALOG, inputTableName, Collections.emptyMap())); assertEquals(inputTableName, getTableResponse.getTableName()); - assertEquals("testCatalog", getTableResponse.getCatalogName()); + assertEquals(TEST_CATALOG, getTableResponse.getCatalogName()); } @Test public void doListTables() throws Exception { BlockAllocator blockAllocator = new BlockAllocatorImpl(); - String schemaName = "TESTSCHEMA"; - ListTablesRequest listTablesRequest = new ListTablesRequest(federatedIdentity, "queryId", "testCatalog", schemaName, null, 3); + ListTablesRequest listTablesRequest = new ListTablesRequest(federatedIdentity, TEST_QUERY_ID, TEST_CATALOG, TEST_SCHEMA, null, 3); DatabaseMetaData mockDatabaseMetaData = mock(DatabaseMetaData.class); ResultSet mockResultSet = mock(ResultSet.class); @@ -434,23 +459,215 @@ public void doListTables() throws Exception 
when(mockResultSet.next()).thenReturn(true).thenReturn(true).thenReturn(true).thenReturn(false);
         when(mockResultSet.getString(3)).thenReturn("TESTTABLE").thenReturn("testtable").thenReturn("testTABLE");
-        when(mockResultSet.getString(2)).thenReturn(schemaName);
+        when(mockResultSet.getString(2)).thenReturn(TEST_SCHEMA);
         mockStatic(JDBCUtil.class);
-        when(JDBCUtil.getSchemaTableName(mockResultSet)).thenReturn(new TableName("TESTSCHEMA", "TESTTABLE"))
-                .thenReturn(new TableName("TESTSCHEMA", "testtable"))
-                .thenReturn(new TableName("TESTSCHEMA", "testTABLE"));
+        when(JDBCUtil.getSchemaTableName(mockResultSet)).thenReturn(new TableName(TEST_SCHEMA, TEST_TABLE))
+                .thenReturn(new TableName(TEST_SCHEMA, "testtable"))
+                .thenReturn(new TableName(TEST_SCHEMA, "testTABLE"));
         when(this.jdbcConnectionFactory.getConnection(any())).thenReturn(connection);
         ListTablesResponse listTablesResponse = this.synapseMetadataHandler.doListTables(blockAllocator, listTablesRequest);
         TableName[] expectedTables = {
-                new TableName("TESTSCHEMA", "TESTTABLE"),
-                new TableName("TESTSCHEMA", "testTABLE"),
-                new TableName("TESTSCHEMA", "testtable")
+                new TableName(TEST_SCHEMA, TEST_TABLE),
+                new TableName(TEST_SCHEMA, "testTABLE"),
+                new TableName(TEST_SCHEMA, "testtable")
         };
         assertEquals(Arrays.toString(expectedTables), listTablesResponse.getTables().toString());
     }
+
+    @Test
+    public void convertDatasourceTypeToArrow_WithSynapseSpecificTypes_ReturnsCorrectArrowTypes() throws SQLException {
+        ResultSetMetaData metaData = mock(ResultSetMetaData.class);
+        Map<String, String> configOptions = new HashMap<>();
+        int precision = 0;
+
+        // Map of Synapse data type -> expected ArrowType
+        Map<String, ArrowType> expectedMappings = new HashMap<>();
+        expectedMappings.put("BIT", org.apache.arrow.vector.types.Types.MinorType.TINYINT.getType());
+        expectedMappings.put("TINYINT", org.apache.arrow.vector.types.Types.MinorType.SMALLINT.getType());
+        expectedMappings.put("NUMERIC", org.apache.arrow.vector.types.Types.MinorType.FLOAT8.getType());
+        expectedMappings.put("SMALLMONEY", org.apache.arrow.vector.types.Types.MinorType.FLOAT8.getType());
+        expectedMappings.put("DATE", org.apache.arrow.vector.types.Types.MinorType.DATEDAY.getType());
+        expectedMappings.put("DATETIME", org.apache.arrow.vector.types.Types.MinorType.DATEMILLI.getType());
+        expectedMappings.put("DATETIME2", org.apache.arrow.vector.types.Types.MinorType.DATEMILLI.getType());
+        expectedMappings.put("SMALLDATETIME", org.apache.arrow.vector.types.Types.MinorType.DATEMILLI.getType());
+        expectedMappings.put("DATETIMEOFFSET", org.apache.arrow.vector.types.Types.MinorType.DATEMILLI.getType());
+
+        int index = 1;
+        for (Map.Entry<String, ArrowType> entry : expectedMappings.entrySet()) {
+            String synapseType = entry.getKey();
+            ArrowType expectedArrowType = entry.getValue();
+
+            when(metaData.getColumnTypeName(index)).thenReturn(synapseType);
+
+            Optional<ArrowType> actual = synapseMetadataHandler.convertDatasourceTypeToArrow(index, precision, configOptions, metaData);
+
+            assertTrue("Expected ArrowType to be present", actual.isPresent());
+            assertEquals(expectedArrowType, actual.get());
+
+            index++;
+        }
+    }
+
+    @Test
+    public void doGetTable_WithAzureServerless_ReturnsCorrectSchema() throws Exception {
+        BlockAllocator blockAllocator = new BlockAllocatorImpl();
+        String[] schema = {"DATA_TYPE", "COLUMN_NAME", "PRECISION", "SCALE"};
+        int[] types = {Types.VARCHAR, Types.VARCHAR, Types.INTEGER, Types.INTEGER};
+
+        Object[][] values = {
+                // VARCHAR group
+                {"varchar", "testCol1", 0, 0}, // varchar
+                {"char", "testCol2", 0, 0}, // char
+ 
{"binary", "testCol3", 0, 0}, // binary + {"nchar", "testCol4", 0, 0}, // nchar + {"nvarchar", "testCol5", 0, 0}, // nvarchar + {"varbinary", "testCol6", 0, 0}, // varbinary + {"time", "testCol7", 0, 0}, // time + {"uniqueidentifier", "testCol8", 0, 0}, // uniqueidentifier + + // Boolean + {"bit", "testCol9", 0, 0}, + + // Integer group + {"tinyint", "testCol10", 0, 0}, + {"smallint", "testCol11", 0, 0}, + {"int", "testCol12", 0, 0}, + {"bigint", "testCol13", 0, 0}, + + // Decimal + {"decimal", "testCol14", 10, 2}, + {"float", "testCol15", 0, 0}, // float + {"float", "testCol16", 0, 0}, // float + {"real", "testCol17", 0, 0}, // real + + // Dates + {"date", "testCol18", 0, 0}, + {"datetime", "testCol19", 0, 0}, // datetime/datetime2 + {"datetimeoffset", "testCol20", 0, 0} // datetimeoffset + }; + + ResultSet resultSet = mockResultSet(schema, types, values, new AtomicInteger(-1)); + + SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder(); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol1", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol2", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol3", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol4", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol5", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol6", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol7", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol8", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build()); + + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol9", org.apache.arrow.vector.types.Types.MinorType.TINYINT.getType()).build()); + + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol10", org.apache.arrow.vector.types.Types.MinorType.SMALLINT.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol11", org.apache.arrow.vector.types.Types.MinorType.SMALLINT.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol12", org.apache.arrow.vector.types.Types.MinorType.INT.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol13", org.apache.arrow.vector.types.Types.MinorType.BIGINT.getType()).build()); + + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol14", new ArrowType.Decimal(10, 2, 256)).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol15", org.apache.arrow.vector.types.Types.MinorType.FLOAT8.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol16", org.apache.arrow.vector.types.Types.MinorType.FLOAT8.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol17", org.apache.arrow.vector.types.Types.MinorType.FLOAT4.getType()).build()); + + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol18", org.apache.arrow.vector.types.Types.MinorType.DATEDAY.getType()).build()); + 
expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol19", org.apache.arrow.vector.types.Types.MinorType.DATEMILLI.getType()).build()); + expectedSchemaBuilder.addField(FieldBuilder.newBuilder("testCol20", org.apache.arrow.vector.types.Types.MinorType.DATEMILLI.getType()).build()); + + PARTITION_SCHEMA.getFields().forEach(expectedSchemaBuilder::addField); + Schema expected = expectedSchemaBuilder.build(); + + PreparedStatement stmt = mock(PreparedStatement.class); + when(connection.prepareStatement(nullable(String.class))).thenReturn(stmt); + when(stmt.executeQuery()).thenReturn(resultSet); + + when(connection.getMetaData().getURL()).thenReturn("jdbc:sqlserver://test-ondemand.sql.azuresynapse.net;databaseName=fakedatabase"); + + TableName inputTableName = new TableName(TEST_SCHEMA, TEST_TABLE); + when(connection.getCatalog()).thenReturn(TEST_CATALOG); + + GetTableResponse getTableResponse = this.synapseMetadataHandler.doGetTable( + blockAllocator, new GetTableRequest(this.federatedIdentity, TEST_QUERY_ID, TEST_CATALOG, inputTableName, Collections.emptyMap())); + + // Compare schemas ignoring order + assertTrue("Schemas do not match when ignoring order", schemasMatchIgnoringOrder(expected, getTableResponse.getSchema())); + assertEquals(inputTableName, getTableResponse.getTableName()); + assertEquals(TEST_CATALOG, getTableResponse.getCatalogName()); + + } + + // Helper method to compare schemas ignoring field order + private boolean schemasMatchIgnoringOrder(Schema expected, Schema actual) + { + if (expected == actual) { + return true; + } + if (expected == null || actual == null) { + return false; + } + List expectedFields = expected.getFields(); + List actualFields = actual.getFields(); + if (expectedFields.size() != actualFields.size()) { + return false; + } + + Map expectedFieldMap = expectedFields.stream() + .collect(Collectors.toMap( + Field::getName, + Field::getType, + (t1, t2) -> t1, // Merge function to handle duplicate keys (not expected here) + LinkedHashMap::new + )); + Map actualFieldMap = actualFields.stream() + .collect(Collectors.toMap( + Field::getName, + Field::getType, + (t1, t2) -> t1, + LinkedHashMap::new + )); + + return expectedFieldMap.equals(actualFieldMap); + } + + @Test + public void doGetDataSourceCapabilities_WithValidRequest_ReturnsCapabilities() + { + BlockAllocator allocator = new BlockAllocatorImpl(); + GetDataSourceCapabilitiesRequest request = + new GetDataSourceCapabilitiesRequest(federatedIdentity, TEST_QUERY_ID, TEST_CATALOG); + + GetDataSourceCapabilitiesResponse response = + synapseMetadataHandler.doGetDataSourceCapabilities(allocator, request); + + Map> capabilities = response.getCapabilities(); + + assertEquals(TEST_CATALOG, response.getCatalogName()); + + // Filter pushdown + List filterPushdown = capabilities.get("supports_filter_pushdown"); + assertNotNull("Expected supports_filter_pushdown capability to be present", filterPushdown); + assertEquals(2, filterPushdown.size()); + assertTrue(filterPushdown.stream().anyMatch(subType -> subType.getSubType().equals("sorted_range_set"))); + assertTrue(filterPushdown.stream().anyMatch(subType -> subType.getSubType().equals("nullable_comparison"))); + + // Complex expression pushdown + List complexPushdown = capabilities.get("supports_complex_expression_pushdown"); + assertNotNull("Expected supports_complex_expression_pushdown capability to be present", complexPushdown); + assertEquals(1, complexPushdown.size()); + OptimizationSubType complexSubType = complexPushdown.get(0); + 
assertEquals("supported_function_expression_types", complexSubType.getSubType()); + assertNotNull("Expected function expression types to be present", complexSubType.getProperties()); + assertFalse("Expected function expression types to be non-empty", complexSubType.getProperties().isEmpty()); + + // Top-N pushdown + List topNPushdown = capabilities.get("supports_top_n_pushdown"); + assertNotNull("Expected supports_top_n_pushdown capability to be present", topNPushdown); + assertEquals(1, topNPushdown.size()); + assertEquals("SUPPORTS_ORDER_BY", topNPushdown.get(0).getSubType()); + } } diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java index f4f4a0fe16..8d8f4768e9 100644 --- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java +++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java @@ -19,17 +19,24 @@ */ package com.amazonaws.athena.connectors.synapse; +import com.amazonaws.athena.connector.credentials.CredentialsProvider; +import com.amazonaws.athena.connector.lambda.QueryStatusChecker; +import com.amazonaws.athena.connector.lambda.data.BlockSpiller; import com.amazonaws.athena.connector.lambda.data.FieldBuilder; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; +import com.amazonaws.athena.connector.lambda.domain.predicate.Marker; +import com.amazonaws.athena.connector.lambda.domain.predicate.OrderByField; import com.amazonaws.athena.connector.lambda.domain.predicate.Range; import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; +import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; +import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; +import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; -import com.amazonaws.athena.connector.credentials.CredentialsProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -43,16 +50,64 @@ import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import static com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature.SCHEMA_FUNCTION_NAME; +import static com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough.QUERY; import static com.amazonaws.athena.connectors.synapse.SynapseConstants.QUOTE_CHARACTER; -import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static 
org.mockito.ArgumentMatchers.nullable; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class SynapseRecordHandlerTest { + private static final String TEST_CATALOG = "testCatalog"; + private static final String TEST_CATALOG_NAME = "testCatalogName"; + private static final String TEST_SCHEMA = "testSchema"; + private static final String TEST_TABLE = "testTable"; + private static final String TEST_QUERY_ID = "testQueryId"; + private static final String TEST_DB_NAME = "fakedatabase"; + private static final String TEST_HOSTNAME = "hostname"; + private static final String TEST_JDBC_URL = String.format("synapse://jdbc:sqlserver://%s;databaseName=%s", TEST_HOSTNAME, TEST_DB_NAME); + private static final TableName TEST_TABLE_NAME = new TableName(TEST_SCHEMA, TEST_TABLE); + + // Test column names + private static final String TEST_COL1 = "testCol1"; + private static final String TEST_COL2 = "testCol2"; + private static final String TEST_COL3 = "testCol3"; + private static final String TEST_COL4 = "testCol4"; + private static final String TEST_ID_COL = "id"; + private static final String TEST_NAME_COL = "name"; + private static final String TEST_CREATED_AT_COL = "created_at"; + + // Test values + private static final String TEST_VARCHAR_VALUE = "varcharTest"; + private static final String TEST_PARTITION_FROM = "100000"; + private static final String TEST_PARTITION_TO = "300000"; + private static final int TEST_ID_1 = 123; + private static final int TEST_ID_2 = 124; + private static final String TEST_NAME_1 = "test1"; + private static final String TEST_NAME_2 = "test2"; + private static final String COL_ID = "id"; + private static final String COL_NAME = "name"; + private static final String COL_VALUE = "value"; + private static final String COL_INT = "intCol"; + private static final String COL_DOUBLE = "doubleCol"; + private static final String COL_STRING = "stringCol"; + private SynapseRecordHandler synapseRecordHandler; private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; @@ -60,6 +115,7 @@ public class SynapseRecordHandlerTest private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; + private FederatedIdentity federatedIdentity; @Before public void setup() @@ -70,103 +126,522 @@ public void setup() this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); - Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(CredentialsProvider.class))).thenReturn(this.connection); + this.federatedIdentity = mock(FederatedIdentity.class); + when(this.jdbcConnectionFactory.getConnection(nullable(CredentialsProvider.class))).thenReturn(this.connection); jdbcSplitQueryBuilder = new SynapseQueryStringBuilder(QUOTE_CHARACTER, new SynapseFederationExpressionParser(QUOTE_CHARACTER)); - final DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", SynapseConstants.NAME, - "synapse://jdbc:sqlserver://hostname;databaseName=fakedatabase"); + final DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig(TEST_CATALOG, SynapseConstants.NAME, + TEST_JDBC_URL); this.synapseRecordHandler = new SynapseRecordHandler(databaseConnectionConfig, amazonS3, secretsManager, athena, jdbcConnectionFactory, jdbcSplitQueryBuilder, com.google.common.collect.ImmutableMap.of()); } - private ValueSet 
getSingleValueSet(Object value) { - Range range = Mockito.mock(Range.class, Mockito.RETURNS_DEEP_STUBS); - Mockito.when(range.isSingleValue()).thenReturn(true); - Mockito.when(range.getLow().getValue()).thenReturn(value); - ValueSet valueSet = Mockito.mock(SortedRangeSet.class, Mockito.RETURNS_DEEP_STUBS); - Mockito.when(valueSet.getRanges().getOrderedRanges()).thenReturn(Collections.singletonList(range)); - return valueSet; - } - @Test public void buildSplitSql() throws SQLException { - TableName tableName = new TableName("testSchema", "testTable"); - - SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol1", Types.MinorType.INT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol2", Types.MinorType.DATEDAY.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol3", Types.MinorType.DATEMILLI.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol4", Types.MinorType.VARCHAR.getType()).build()); - Schema schema = schemaBuilder.build(); + Schema schema = SchemaBuilder.newBuilder() + .addField(FieldBuilder.newBuilder(TEST_COL1, Types.MinorType.INT.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_COL2, Types.MinorType.DATEDAY.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_COL3, Types.MinorType.DATEMILLI.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_COL4, Types.MinorType.VARCHAR.getType()).build()) + .build(); Split split = Mockito.mock(Split.class); - Mockito.when(split.getProperty(SynapseMetadataHandler.PARTITION_COLUMN)).thenReturn("id"); - Mockito.when(split.getProperty(SynapseMetadataHandler.PARTITION_BOUNDARY_FROM)).thenReturn("100000"); - Mockito.when(split.getProperty(SynapseMetadataHandler.PARTITION_BOUNDARY_TO)).thenReturn("300000"); + when(split.getProperty(SynapseMetadataHandler.PARTITION_COLUMN)).thenReturn(TEST_ID_COL); + when(split.getProperty(SynapseMetadataHandler.PARTITION_BOUNDARY_FROM)).thenReturn(TEST_PARTITION_FROM); + when(split.getProperty(SynapseMetadataHandler.PARTITION_BOUNDARY_TO)).thenReturn(TEST_PARTITION_TO); - ValueSet valueSet = getSingleValueSet("varcharTest"); + ValueSet valueSet = getSingleValueSet(TEST_VARCHAR_VALUE); Constraints constraints = Mockito.mock(Constraints.class); - Mockito.when(constraints.getSummary()).thenReturn(new ImmutableMap.Builder() - .put("testCol4", valueSet) + when(constraints.getSummary()).thenReturn(new ImmutableMap.Builder() + .put(TEST_COL4, valueSet) .build()); - Mockito.when(constraints.getLimit()).thenReturn(5L); + when(constraints.getLimit()).thenReturn(5L); - String expectedSql = "SELECT \"testCol1\", \"testCol2\", \"testCol3\", \"testCol4\" FROM \"testSchema\".\"testTable\" WHERE (\"testCol4\" = ?) AND id > 100000 and id <= 300000"; - PreparedStatement expectedPreparedStatement = Mockito.mock(PreparedStatement.class); - Mockito.when(this.connection.prepareStatement(Mockito.eq(expectedSql))).thenReturn(expectedPreparedStatement); - PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, "testCatalogName", tableName, schema, constraints, split); + String expectedSql = "SELECT \"" + TEST_COL1 + "\", \"" + TEST_COL2 + "\", \"" + TEST_COL3 + "\", \"" + TEST_COL4 + "\" FROM \"" + TEST_SCHEMA + "\".\"" + TEST_TABLE + "\" WHERE (\"" + TEST_COL4 + "\" = ?) 
AND " + TEST_ID_COL + " > " + TEST_PARTITION_FROM + " and " + TEST_ID_COL + " <= " + TEST_PARTITION_TO; + PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql); + PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, TEST_TABLE_NAME, schema, constraints, split); Assert.assertEquals(expectedPreparedStatement, preparedStatement); - Mockito.verify(preparedStatement, Mockito.times(1)).setString(1, "varcharTest"); + verify(preparedStatement, Mockito.times(1)).setString(1, TEST_VARCHAR_VALUE); } @Test public void buildSplitSqlWithPartition() throws SQLException { - TableName tableName = new TableName("testSchema", "testTable"); - - SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol1", Types.MinorType.INT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol2", Types.MinorType.DATEDAY.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol3", Types.MinorType.DATEMILLI.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol4", Types.MinorType.VARBINARY.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("partition", Types.MinorType.VARCHAR.getType()).build()); + Schema schema = SchemaBuilder.newBuilder() + .addField(FieldBuilder.newBuilder(TEST_COL1, Types.MinorType.INT.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_COL2, Types.MinorType.DATEDAY.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_COL3, Types.MinorType.DATEMILLI.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_COL4, Types.MinorType.VARBINARY.getType()).build()) + .addField(FieldBuilder.newBuilder("partition", Types.MinorType.VARCHAR.getType()).build()) + .build(); + + // Test case 1: Normal partition boundaries + Split split = mockSplitWithPartitionProperties("0", "100000", "1"); + Constraints constraints = Mockito.mock(Constraints.class); + PreparedStatement expectedPreparedStatement = Mockito.mock(PreparedStatement.class); + when(this.connection.prepareStatement(nullable(String.class))).thenReturn(expectedPreparedStatement); + this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, TEST_TABLE_NAME, schema, constraints, split); + + // Test case 2: Empty from boundary + split = mockSplitWithPartitionProperties(" ", "100000", "1"); + this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, TEST_TABLE_NAME, schema, constraints, split); + + // Test case 3: Empty to boundary + split = mockSplitWithPartitionProperties("300000", " ", "2"); + this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, TEST_TABLE_NAME, schema, constraints, split); + + // Test case 4: Both boundaries empty + split = mockSplitWithPartitionProperties(" ", " ", "2"); + this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, TEST_TABLE_NAME, schema, constraints, split); + } + + @Test + public void readWithConstraint_WithValidData_ProcessesRows() throws Exception { + Schema schema = SchemaBuilder.newBuilder() + .addField(FieldBuilder.newBuilder(TEST_ID_COL, Types.MinorType.INT.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_NAME_COL, Types.MinorType.VARCHAR.getType()).build()) + .addField(FieldBuilder.newBuilder(TEST_CREATED_AT_COL, Types.MinorType.DATEMILLI.getType()).build()) + .build(); + + Split split = Mockito.mock(Split.class); + when(split.getProperties()).thenReturn(Collections.emptyMap()); + + // Setup mock result 
set with actual test data + ResultSet resultSet = Mockito.mock(ResultSet.class); + when(resultSet.next()).thenReturn(true, true, false); // Return true twice for two rows, then false + when(resultSet.getInt(TEST_ID_COL)).thenReturn(TEST_ID_1, TEST_ID_2); + when(resultSet.getString(TEST_NAME_COL)).thenReturn(TEST_NAME_1, TEST_NAME_2); + when(resultSet.getTimestamp(TEST_CREATED_AT_COL)).thenReturn(new Timestamp(System.currentTimeMillis())); + + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + when(connection.prepareStatement(Mockito.anyString())).thenReturn(preparedStatement); + when(preparedStatement.executeQuery()).thenReturn(resultSet); + + DatabaseMetaData metaData = Mockito.mock(DatabaseMetaData.class); + when(connection.getMetaData()).thenReturn(metaData); + when(metaData.getURL()).thenReturn("jdbc:sqlserver://test.sql.azuresynapse.net:1433;databaseName=testdb;"); + + ReadRecordsRequest request = new ReadRecordsRequest( + federatedIdentity, + TEST_CATALOG, + TEST_QUERY_ID, + TEST_TABLE_NAME, + schema, + split, + new Constraints(Collections.emptyMap(), Collections.emptyList(), Collections.emptyList(), 1000, Collections.emptyMap(),null), + 0, + 0 + ); + + BlockSpiller spiller = Mockito.mock(BlockSpiller.class); + QueryStatusChecker queryStatusChecker = mock(QueryStatusChecker.class); + when(queryStatusChecker.isQueryRunning()).thenReturn(true); + + // Execute the test + synapseRecordHandler.readWithConstraint(spiller, request, queryStatusChecker); + + // Verify that writeRows was called twice (once for each row) + verify(spiller, Mockito.times(2)).writeRows(Mockito.any()); + verify(resultSet, Mockito.times(3)).next(); // Called 3 times (2 true, 1 false) + + } + + @Test + public void buildSplitSql_WithOrderBy_ReturnsCorrectSql() throws SQLException { + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); + SchemaBuilder schemaBuilder = createSchemaWithCommonFields(); + schemaBuilder.addField(FieldBuilder.newBuilder(COL_VALUE, Types.MinorType.FLOAT8.getType()).build()); Schema schema = schemaBuilder.build(); + Split split = createMockSplit(); + + List orderByFields = new ArrayList<>(); + orderByFields.add(new OrderByField(COL_VALUE, OrderByField.Direction.DESC_NULLS_LAST)); + orderByFields.add(new OrderByField(COL_NAME, OrderByField.Direction.ASC_NULLS_LAST)); + + Constraints constraints = new Constraints( + Collections.emptyMap(), + Collections.emptyList(), + orderByFields, + Constraints.DEFAULT_NO_LIMIT, + Collections.emptyMap(), + null + ); + + String expectedSql = "SELECT \"id\", \"name\", \"value\" FROM \"testSchema\".\"testTable\" WHERE id > 100000 and id <= 300000 ORDER BY \"value\" DESC NULLS LAST, \"name\" ASC NULLS LAST"; + PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql); + PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG_NAME, tableName, schema, constraints, split); + + Assert.assertEquals(expectedPreparedStatement, preparedStatement); + verifyFetchSize(expectedPreparedStatement); + } + + @Test + public void buildSplitSql_WithComplexExpressions_ReturnsCorrectSql() throws SQLException { + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); + SchemaBuilder schemaBuilder = createSchemaWithCommonFields(); + schemaBuilder.addField(FieldBuilder.newBuilder(COL_DOUBLE, Types.MinorType.FLOAT8.getType()).build()); + Schema schema = schemaBuilder.build(); + + Split split = createMockSplit(); + ValueSet nameValueSet = 
getRangeSet(Marker.Bound.EXACTLY, "test", Marker.Bound.BELOW, "tesu"); + ValueSet doubleValueSet = getRangeSet(Marker.Bound.EXACTLY, 1.0d, Marker.Bound.EXACTLY, 2.0d); + + Constraints constraints = new Constraints( + new ImmutableMap.Builder() + .put(COL_NAME, nameValueSet) + .put(COL_DOUBLE, doubleValueSet) + .build(), + Collections.emptyList(), + Collections.emptyList(), + Constraints.DEFAULT_NO_LIMIT, + Collections.emptyMap(), + null + ); + + String expectedSql = "SELECT \"id\", \"name\", \"doubleCol\" FROM \"testSchema\".\"testTable\" WHERE ((\"name\" >= ? AND \"name\" < ?)) AND ((\"doubleCol\" >= ? AND \"doubleCol\" <= ?)) AND id > 100000 and id <= 300000"; + PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql); + + PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG_NAME, tableName, schema, constraints, split); + + Assert.assertEquals(expectedPreparedStatement, preparedStatement); + verifyFetchSize(expectedPreparedStatement); + Mockito.verify(preparedStatement, Mockito.times(1)).setString(1, "test"); + Mockito.verify(preparedStatement, Mockito.times(1)).setString(2, "tesu"); + Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(3, 1.0d); + Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(4, 2.0d); + } + + @Test + public void buildSplitSql_WithValueComparisons_ReturnsCorrectSql() throws SQLException { + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); + SchemaBuilder schemaBuilder = createSchemaWithCommonFields(); + schemaBuilder.addField(FieldBuilder.newBuilder(COL_INT, Types.MinorType.INT.getType()).build()); + Schema schema = schemaBuilder.build(); + Split split = createMockSplit(); + + ValueSet stringValueSet = getSingleValueSet("testValue"); + ValueSet intValueSet = getSingleValueSet(42); + + Map summary = new ImmutableMap.Builder() + .put(COL_NAME, stringValueSet) + .put(COL_INT, intValueSet) + .build(); + + Constraints constraints = new Constraints( + summary, + Collections.emptyList(), + Collections.emptyList(), + Constraints.DEFAULT_NO_LIMIT, + Collections.emptyMap(), + null); + + String expectedSql = "SELECT \"id\", \"name\", \"intCol\" FROM \"testSchema\".\"testTable\" WHERE (\"name\" = ?) AND (\"intCol\" = ?) 
AND id > 100000 and id <= 300000"; + PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql); + + PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG_NAME, tableName, schema, constraints, split); + + Assert.assertEquals(expectedPreparedStatement, preparedStatement); + + Mockito.verify(preparedStatement, Mockito.times(1)).setString(1, "testValue"); + Mockito.verify(preparedStatement, Mockito.times(1)).setInt(2, 42); + verifyFetchSize(expectedPreparedStatement); + } + + @Test + public void buildSplitSql_WithEmptyConstraints_ReturnsCorrectSql() throws SQLException { + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); + Schema schema = createSchemaWithCommonFields().build(); + Split split = createMockSplit(); + Constraints constraints = new Constraints( + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + Constraints.DEFAULT_NO_LIMIT, + Collections.emptyMap(), + null + ); + + String expectedSql = "SELECT \"" + COL_ID + "\", \"" + COL_NAME + "\" FROM \"testSchema\".\"testTable\" WHERE id > 100000 and id <= 300000"; + PreparedStatement preparedStatement = createMockPreparedStatement(expectedSql); + + PreparedStatement result = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, tableName, schema, constraints, split); + + Assert.assertEquals(preparedStatement, result); + verifyFetchSize(preparedStatement); + } + + @Test + public void buildSplitSql_WithLimitOffset_ReturnsCorrectSql() throws SQLException { + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); + Schema schema = createSchemaWithValueField().build(); + Split split = createMockSplit(); + + Constraints constraints = new Constraints( + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + 5L, + Collections.emptyMap(), + null + ); + + // Expected SQL should NOT contain LIMIT clause as Synapse does not support LIMIT clause + String expectedSql = "SELECT \"id\", \"name\", \"value\" FROM \"testSchema\".\"testTable\" WHERE id > 100000 and id <= 300000"; + PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql); + + PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG_NAME, tableName, schema, constraints, split); + + Assert.assertEquals(expectedPreparedStatement, preparedStatement); + verifyFetchSize(expectedPreparedStatement); + } + + @Test + public void buildSplitSql_WithRangeAndInPredicates_ReturnsCorrectSql() throws SQLException { + TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE); + SchemaBuilder schemaBuilder = createSchemaWithCommonFields(); + schemaBuilder.addField(FieldBuilder.newBuilder(COL_INT, Types.MinorType.INT.getType()).build()); + schemaBuilder.addField(FieldBuilder.newBuilder(COL_DOUBLE, Types.MinorType.FLOAT8.getType()).build()); + schemaBuilder.addField(FieldBuilder.newBuilder(COL_STRING, Types.MinorType.VARCHAR.getType()).build()); + Schema schema = schemaBuilder.build(); + + Split split = createMockSplit(); + + ValueSet intValueSet = getSingleValueSet(Arrays.asList(1, 2, 3)); + ValueSet doubleValueSet = getRangeSet(Marker.Bound.EXACTLY, 1.5d, Marker.Bound.BELOW, 5.5d); + ValueSet stringValueSet = getSingleValueSet(Arrays.asList("value1", "value2")); + + Map summary = new ImmutableMap.Builder() + .put(COL_INT, intValueSet) + .put(COL_DOUBLE, doubleValueSet) + .put(COL_STRING, stringValueSet) + .build(); + + Constraints constraints = new 
+    @Test
+    public void buildSplitSql_WithRangeAndInPredicates_ReturnsCorrectSql() throws SQLException {
+        TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE);
+        SchemaBuilder schemaBuilder = createSchemaWithCommonFields();
+        schemaBuilder.addField(FieldBuilder.newBuilder(COL_INT, Types.MinorType.INT.getType()).build());
+        schemaBuilder.addField(FieldBuilder.newBuilder(COL_DOUBLE, Types.MinorType.FLOAT8.getType()).build());
+        schemaBuilder.addField(FieldBuilder.newBuilder(COL_STRING, Types.MinorType.VARCHAR.getType()).build());
+        Schema schema = schemaBuilder.build();
+
+        Split split = createMockSplit();
+
+        ValueSet intValueSet = getSingleValueSet(Arrays.asList(1, 2, 3));
+        ValueSet doubleValueSet = getRangeSet(Marker.Bound.EXACTLY, 1.5d, Marker.Bound.BELOW, 5.5d);
+        ValueSet stringValueSet = getSingleValueSet(Arrays.asList("value1", "value2"));
+
+        Map<String, ValueSet> summary = new ImmutableMap.Builder<String, ValueSet>()
+                .put(COL_INT, intValueSet)
+                .put(COL_DOUBLE, doubleValueSet)
+                .put(COL_STRING, stringValueSet)
+                .build();
+
+        Constraints constraints = new Constraints(
+                summary,
+                Collections.emptyList(),
+                Collections.emptyList(),
+                Constraints.DEFAULT_NO_LIMIT,
+                Collections.emptyMap(),
+                null
+        );
+
+        String expectedSql = "SELECT \"id\", \"name\", \"intCol\", \"doubleCol\", \"stringCol\" FROM \"testSchema\".\"testTable\" WHERE (\"intCol\" IN (?,?,?)) AND ((\"doubleCol\" >= ? AND \"doubleCol\" < ?)) AND (\"stringCol\" IN (?,?)) AND id > 100000 and id <= 300000";
+        PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql);
+
+        PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG_NAME, tableName, schema, constraints, split);
+
+        Assert.assertEquals(expectedPreparedStatement, preparedStatement);
+        verifyFetchSize(expectedPreparedStatement);
+
+        Mockito.verify(preparedStatement, Mockito.times(1)).setInt(1, 1);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setInt(2, 2);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setInt(3, 3);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(4, 1.5d);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(5, 5.5d);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setString(6, "value1");
+        Mockito.verify(preparedStatement, Mockito.times(1)).setString(7, "value2");
+    }
+
+    @Test
+    public void buildSplitSql_WithQueryPassthrough_ReturnsCorrectSql() throws SQLException {
+        TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE);
+        Schema schema = createSchemaWithCommonFields().build();
+
+        Split split = createMockSplit();
+
+        String passthroughQuery = "SELECT * FROM testSchema.testTable WHERE id > 100";
+        Map<String, String> passthroughArgs = new HashMap<>();
+        passthroughArgs.put(QUERY, passthroughQuery);
+        passthroughArgs.put(SCHEMA_FUNCTION_NAME, "SYSTEM.QUERY");
+        Constraints constraints = new Constraints(
+                Collections.emptyMap(),
+                Collections.emptyList(),
+                Collections.emptyList(),
+                Constraints.DEFAULT_NO_LIMIT,
+                passthroughArgs,
+                null);
+
+        PreparedStatement expectedPreparedStatement = createMockPreparedStatement(passthroughQuery);
+
+        PreparedStatement result = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, tableName, schema, constraints, split);
+
+        Assert.assertEquals(expectedPreparedStatement, result);
+        verifyFetchSize(expectedPreparedStatement);
+    }
+
+    @Test(expected = AthenaConnectorException.class)
+    public void buildSplitSql_WithInvalidQueryPassthrough_ThrowsAthenaConnectorException() throws SQLException {
+        TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE);
+        Schema schema = createSchemaWithCommonFields().build();
+
+        Split split = createMockSplit();
+
+        // Only QUERY is set here; without SCHEMA_FUNCTION_NAME the passthrough arguments are incomplete.
+        Map<String, String> passthroughArgs = new HashMap<>();
+        passthroughArgs.put(QUERY, "SELECT * FROM table");
+        Constraints constraints = new Constraints(
+                Collections.emptyMap(),
+                Collections.emptyList(),
+                Collections.emptyList(),
+                Constraints.DEFAULT_NO_LIMIT,
+                passthroughArgs,
+                null);
+
+        synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, tableName, schema, constraints, split);
+    }
+
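+    // Order-by fields should render after the partition predicate as ASC/DESC with NULLS LAST.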
+    @Test
+    public void buildSplitSql_WithComplexConstraintsAndOrderBy_ReturnsCorrectSql() throws SQLException {
+        TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE);
+        SchemaBuilder schemaBuilder = createSchemaWithValueField();
+        schemaBuilder.addField(FieldBuilder.newBuilder("dateCol", Types.MinorType.DATEDAY.getType()).build());
+        Schema schema = schemaBuilder.build();
+        Split split = createMockSplit();
+
+        ValueSet valueRangeSet = getRangeSet(Marker.Bound.ABOVE, 10.0, Marker.Bound.EXACTLY, 100.0);
+        ValueSet nameValueSet = getSingleValueSet("testName");
+
+        List<OrderByField> orderByFields = new ArrayList<>();
+        orderByFields.add(new OrderByField(COL_ID, OrderByField.Direction.ASC_NULLS_LAST));
+        orderByFields.add(new OrderByField(COL_VALUE, OrderByField.Direction.DESC_NULLS_LAST));
+
+        Constraints constraints = new Constraints(
+                new ImmutableMap.Builder<String, ValueSet>()
+                        .put(COL_VALUE, valueRangeSet)
+                        .put(COL_NAME, nameValueSet)
+                        .build(),
+                Collections.emptyList(),
+                orderByFields,
+                Constraints.DEFAULT_NO_LIMIT,
+                Collections.emptyMap(),
+                null
+        );
+
+        String expectedSql = "SELECT \"id\", \"name\", \"value\", \"dateCol\" FROM \"testSchema\".\"testTable\" WHERE (\"name\" = ?) AND ((\"value\" > ? AND \"value\" <= ?)) AND id > 100000 and id <= 300000 ORDER BY \"id\" ASC NULLS LAST, \"value\" DESC NULLS LAST";
+        PreparedStatement expectedPreparedStatement = createMockPreparedStatement(expectedSql);
+
+        PreparedStatement preparedStatement = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG_NAME, tableName, schema, constraints, split);
+
+        Assert.assertEquals(expectedPreparedStatement, preparedStatement);
+        verifyFetchSize(expectedPreparedStatement);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setString(1, "testName");
+        Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(2, 10.0);
+        Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(3, 100.0);
+    }
+
+    @Test
+    public void buildSplitSql_WithEmptyConstraintsAndOrderBy_ReturnsCorrectSql() throws SQLException {
+        TableName tableName = new TableName(TEST_SCHEMA, TEST_TABLE);
+        Schema schema = createSchemaWithCommonFields().build();
+        Split split = createMockSplit();
+
+        List<OrderByField> orderByFields = new ArrayList<>();
+        orderByFields.add(new OrderByField(COL_ID, OrderByField.Direction.ASC_NULLS_LAST));
+        orderByFields.add(new OrderByField(COL_NAME, OrderByField.Direction.DESC_NULLS_LAST));
+
+        Constraints constraints = new Constraints(
+                Collections.emptyMap(),
+                Collections.emptyList(),
+                orderByFields,
+                Constraints.DEFAULT_NO_LIMIT,
+                Collections.emptyMap(),
+                null
+        );
+
+        String expectedSql = "SELECT \"" + COL_ID + "\", \"" + COL_NAME + "\" FROM \"testSchema\".\"testTable\" WHERE id > 100000 and id <= 300000 ORDER BY \"" + COL_ID + "\" ASC NULLS LAST, \"" + COL_NAME + "\" DESC NULLS LAST";
+        PreparedStatement preparedStatement = createMockPreparedStatement(expectedSql);
+
+        PreparedStatement result = this.synapseRecordHandler.buildSplitSql(this.connection, TEST_CATALOG, tableName, schema, constraints, split);
+
+        Assert.assertEquals(preparedStatement, result);
+        verifyFetchSize(preparedStatement);
+    }
+
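+    // Test fixtures: the helpers below build Mockito stubs for splits, value sets, schemas, and prepared statements.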
+    private Split mockSplitWithPartitionProperties(String from, String to, String partitionNumber)
+    {
         Split split = Mockito.mock(Split.class);
-        Mockito.when(split.getProperties()).thenReturn(com.google.common.collect.ImmutableMap.of("PARTITION_BOUNDARY_FROM", "0", "PARTITION_NUMBER", "1", "PARTITION_COLUMN", "testCol1", "PARTITION_BOUNDARY_TO", "100000"));
-        Mockito.when(split.getProperty(Mockito.eq("PARTITION_BOUNDARY_FROM"))).thenReturn("0");
-        Mockito.when(split.getProperty(Mockito.eq("PARTITION_NUMBER"))).thenReturn("1");
-        Mockito.when(split.getProperty(Mockito.eq("PARTITION_COLUMN"))).thenReturn("testCol1");
-        Mockito.when(split.getProperty(Mockito.eq("PARTITION_BOUNDARY_TO"))).thenReturn("100000");
+        ImmutableMap<String, String> properties = ImmutableMap.of(
+                SynapseMetadataHandler.PARTITION_BOUNDARY_FROM, from,
+                SynapseMetadataHandler.PARTITION_NUMBER, partitionNumber,
+                SynapseMetadataHandler.PARTITION_COLUMN, TEST_COL1,
+                SynapseMetadataHandler.PARTITION_BOUNDARY_TO, to
+        );
+        when(split.getProperties()).thenReturn(properties);
+        when(split.getProperty(eq(SynapseMetadataHandler.PARTITION_BOUNDARY_FROM))).thenReturn(from);
+        when(split.getProperty(eq(SynapseMetadataHandler.PARTITION_NUMBER))).thenReturn(partitionNumber);
+        when(split.getProperty(eq(SynapseMetadataHandler.PARTITION_COLUMN))).thenReturn(TEST_COL1);
+        when(split.getProperty(eq(SynapseMetadataHandler.PARTITION_BOUNDARY_TO))).thenReturn(to);
+        return split;
+    }
 
-        Constraints constraints = Mockito.mock(Constraints.class);
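+    // Builds a mocked two-bound range wrapped in a SortedRangeSet; deep stubs let the getLow()/getHigh() chains return the given bounds.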
Mockito.when(split.getProperty(Mockito.eq("PARTITION_BOUNDARY_FROM"))).thenReturn("300000"); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_NUMBER"))).thenReturn("1"); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_COLUMN"))).thenReturn("testCol1"); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_BOUNDARY_TO"))).thenReturn(" "); - this.synapseRecordHandler.buildSplitSql(this.connection, "testCatalogName", tableName, schema, constraints, split); - - Mockito.when(split.getProperties()).thenReturn(com.google.common.collect.ImmutableMap.of("PARTITION_BOUNDARY_FROM", " ", "PARTITION_NUMBER", "2", "PARTITION_COLUMN", "testCol1", "PARTITION_BOUNDARY_TO", " ")); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_BOUNDARY_FROM"))).thenReturn(" "); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_NUMBER"))).thenReturn("1"); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_COLUMN"))).thenReturn("testCol1"); - Mockito.when(split.getProperty(Mockito.eq("PARTITION_BOUNDARY_TO"))).thenReturn(" "); - this.synapseRecordHandler.buildSplitSql(this.connection, "testCatalogName", tableName, schema, constraints, split); + Mockito.when(this.connection.prepareStatement(Mockito.eq(expectedSql))).thenReturn(expectedPreparedStatement); + return expectedPreparedStatement; + } + + private void verifyFetchSize(PreparedStatement preparedStatement) throws SQLException { + Mockito.verify(preparedStatement, Mockito.atLeastOnce()).setFetchSize(1000); + } + + private ValueSet getSingleValueSet(Object value) { + Range range = Mockito.mock(Range.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(range.isSingleValue()).thenReturn(true); + Mockito.when(range.getLow().getValue()).thenReturn(value); + ValueSet valueSet = Mockito.mock(SortedRangeSet.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(valueSet.getRanges().getOrderedRanges()).thenReturn(Collections.singletonList(range)); + return valueSet; + } + + private ValueSet getSingleValueSet(List values) { + List ranges = values.stream().map(value -> { + Range range = Mockito.mock(Range.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(range.isSingleValue()).thenReturn(true); + Mockito.when(range.getLow().getValue()).thenReturn(value); + return range; + }).collect(Collectors.toList()); + + ValueSet valueSet = Mockito.mock(SortedRangeSet.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(valueSet.getRanges().getOrderedRanges()).thenReturn(ranges); + return valueSet; + } + + + private SchemaBuilder createSchemaWithValueField() { + return createSchemaWithCommonFields() + .addField(FieldBuilder.newBuilder(COL_VALUE, Types.MinorType.FLOAT8.getType()).build()); } } diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/resolver/SynapseJDBCCaseResolverTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/resolver/SynapseJDBCCaseResolverTest.java new file mode 100644 index 0000000000..819aff9404 --- /dev/null +++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/resolver/SynapseJDBCCaseResolverTest.java @@ -0,0 +1,117 @@ +/*- + * #%L + * athena-synapse + * %% + * Copyright (C) 2019 - 2025 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package com.amazonaws.athena.connectors.synapse.resolver;
+
+import com.amazonaws.athena.connector.lambda.domain.TableName;
+import com.amazonaws.athena.connector.lambda.resolver.CaseResolver;
+import com.amazonaws.athena.connectors.jdbc.TestBase;
+import com.amazonaws.athena.connectors.jdbc.resolver.DefaultJDBCCaseResolver;
+import com.amazonaws.athena.connectors.synapse.SynapseConstants;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Types;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static com.amazonaws.athena.connector.lambda.resolver.CaseResolver.CASING_MODE_CONFIGURATION_KEY;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+public class SynapseJDBCCaseResolverTest extends TestBase
+{
+    private static final String TEST_SCHEMA_NAME = "oRaNgE";
+    private static final String TEST_TABLE_NAME = "ApPlE";
+    private static final String SCHEMA_COLUMN = "SCHEMA_NAME";
+    private static final String TABLE_COLUMN = "TABLE_NAME";
+    private static final String CASE_INSENSITIVE_MODE = CaseResolver.FederationSDKCasingMode.CASE_INSENSITIVE_SEARCH.name();
+
+    private Connection mockConnection;
+    private PreparedStatement preparedStatement;
+
+    @Before
+    public void setup() throws SQLException
+    {
+        mockConnection = Mockito.mock(Connection.class);
+        preparedStatement = Mockito.mock(PreparedStatement.class);
+        when(mockConnection.prepareStatement(any())).thenReturn(preparedStatement);
+    }
+
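+    // Under CASE_INSENSITIVE_SEARCH, the resolver should return names in whatever casing the catalog lookup reports.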
+    @Test
+    public void getAdjustedSchemaAndTableName_withCaseInsensitiveMode_tableNameIsUpperCased() throws SQLException {
+        DefaultJDBCCaseResolver resolver = new SynapseJDBCCaseResolver(SynapseConstants.NAME);
+
+        // Mock schema name result
+        ResultSet schemaResultSet = mockResultSet(
+                new String[]{SCHEMA_COLUMN},
+                new int[]{Types.VARCHAR},
+                new Object[][]{{TEST_SCHEMA_NAME.toLowerCase()}},
+                new AtomicInteger(-1));
+        when(preparedStatement.executeQuery()).thenReturn(schemaResultSet);
+
+        String adjustedSchemaName = resolver.getAdjustedSchemaNameString(mockConnection, TEST_SCHEMA_NAME, Map.of(
+                CASING_MODE_CONFIGURATION_KEY, CASE_INSENSITIVE_MODE));
+        assertEquals(TEST_SCHEMA_NAME.toLowerCase(), adjustedSchemaName);
+
+        // Mock table name result
+        ResultSet tableResultSet = mockResultSet(
+                new String[]{TABLE_COLUMN},
+                new int[]{Types.VARCHAR},
+                new Object[][]{{TEST_TABLE_NAME.toUpperCase()}},
+                new AtomicInteger(-1));
+        when(preparedStatement.executeQuery()).thenReturn(tableResultSet);
+
+        String adjustedTableName = resolver.getAdjustedTableNameString(mockConnection, TEST_SCHEMA_NAME, TEST_TABLE_NAME, Map.of(
+                CASING_MODE_CONFIGURATION_KEY, CASE_INSENSITIVE_MODE));
+        assertEquals(TEST_TABLE_NAME.toUpperCase(), adjustedTableName);
+    }
+
+    @Test
+    public void getAdjustedTableNameObject_withCaseInsensitiveMode_schemaNameIsLowerCased() throws SQLException {
+        DefaultJDBCCaseResolver resolver = new SynapseJDBCCaseResolver(SynapseConstants.NAME);
+
+        // Mock schema and table result sets
+        ResultSet schemaResultSet = mockResultSet(
+                new String[]{SCHEMA_COLUMN},
+                new int[]{Types.VARCHAR},
+                new Object[][]{{TEST_SCHEMA_NAME.toLowerCase()}},
+                new AtomicInteger(-1));
+
+        ResultSet tableResultSet = mockResultSet(
+                new String[]{TABLE_COLUMN},
+                new int[]{Types.VARCHAR},
+                new Object[][]{{TEST_TABLE_NAME.toUpperCase()}},
+                new AtomicInteger(-1));
+
+        when(preparedStatement.executeQuery()).thenReturn(schemaResultSet).thenReturn(tableResultSet);
+
+        TableName adjusted = resolver.getAdjustedTableNameObject(
+                mockConnection,
+                new TableName(TEST_SCHEMA_NAME, TEST_TABLE_NAME),
+                Map.of(CASING_MODE_CONFIGURATION_KEY, CASE_INSENSITIVE_MODE));
+        assertEquals(new TableName(TEST_SCHEMA_NAME.toLowerCase(), TEST_TABLE_NAME.toUpperCase()), adjusted);
+    }
+}
\ No newline at end of file