Skip to content

Commit 796669d

Browse files
committed
add serverless datalakegen2 support
1 parent 3e9e1e1 commit 796669d

File tree

7 files changed

+546
-17
lines changed

7 files changed

+546
-17
lines changed

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2Constants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public class DataLakeGen2Constants
2525
public static final String DRIVER_CLASS = "com.microsoft.sqlserver.jdbc.SQLServerDriver";
2626
public static final int DEFAULT_PORT = 1433;
2727
public static final String QUOTE_CHARACTER = "\"";
28+
public static final String SQL_POOL = "azureServerless";
2829

2930
private DataLakeGen2Constants() {}
3031
}

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,9 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
227227
String dataType;
228228
String columnName;
229229
HashMap<String, String> hashMap = new HashMap<>();
230-
boolean found = false;
231230

232231
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
233-
try (ResultSet resultSet = getColumns(jdbcConnection.getCatalog(), tableName, jdbcConnection.getMetaData());
234-
Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
232+
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
235233
PreparedStatement stmt = connection.prepareStatement(dataTypeQuery)) {
236234
// fetch data types of columns and prepare map with column name and datatype.
237235
stmt.setString(1, tableName.getSchemaName() + "." + tableName.getTableName());
@@ -242,18 +240,94 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
242240
hashMap.put(columnName.trim(), dataType.trim());
243241
}
244242
}
243+
}
244+
245+
String environment = DataLakeGen2Util.checkEnvironment(jdbcConnection.getMetaData().getURL());
246+
247+
if (DataLakeGen2Constants.SQL_POOL.equalsIgnoreCase(environment)) {
248+
// getColumns() method from SQL Server driver is causing an exception in case of Azure Serverless environment.
249+
// so doing explicit data type conversion
250+
schemaBuilder = doDataTypeConversion(hashMap);
251+
}
252+
else {
253+
schemaBuilder = doDataTypeConversionForNonCompatible(jdbcConnection, tableName, hashMap);
254+
}
255+
// add partition columns
256+
partitionSchema.getFields().forEach(schemaBuilder::addField);
257+
return schemaBuilder.build();
258+
}
259+
260+
private SchemaBuilder doDataTypeConversion(HashMap<String, String> columnNameAndDataTypeMap)
261+
{
262+
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
263+
264+
for (Map.Entry<String, String> entry : columnNameAndDataTypeMap.entrySet()) {
265+
String columnName = entry.getKey();
266+
String dataType = entry.getValue();
267+
ArrowType columnType = Types.MinorType.VARCHAR.getType();
268+
269+
if ("char".equalsIgnoreCase(dataType) || "varchar".equalsIgnoreCase(dataType) || "binary".equalsIgnoreCase(dataType) ||
270+
"nchar".equalsIgnoreCase(dataType) || "nvarchar".equalsIgnoreCase(dataType) || "varbinary".equalsIgnoreCase(dataType)
271+
|| "time".equalsIgnoreCase(dataType) || "uniqueidentifier".equalsIgnoreCase(dataType)) {
272+
columnType = Types.MinorType.VARCHAR.getType();
273+
}
274+
275+
if ("bit".equalsIgnoreCase(dataType)) {
276+
columnType = Types.MinorType.TINYINT.getType();
277+
}
278+
279+
if ("tinyint".equalsIgnoreCase(dataType) || "smallint".equalsIgnoreCase(dataType)) {
280+
columnType = Types.MinorType.SMALLINT.getType();
281+
}
282+
283+
if ("int".equalsIgnoreCase(dataType)) {
284+
columnType = Types.MinorType.INT.getType();
285+
}
286+
287+
if ("bigint".equalsIgnoreCase(dataType)) {
288+
columnType = Types.MinorType.BIGINT.getType();
289+
}
290+
291+
if ("decimal".equalsIgnoreCase(dataType) || "money".equalsIgnoreCase(dataType)) {
292+
columnType = Types.MinorType.FLOAT8.getType();
293+
}
294+
295+
if ("numeric".equalsIgnoreCase(dataType) || "float".equalsIgnoreCase(dataType) || "smallmoney".equalsIgnoreCase(dataType)) {
296+
columnType = Types.MinorType.FLOAT8.getType();
297+
}
245298

299+
if ("real".equalsIgnoreCase(dataType)) {
300+
columnType = Types.MinorType.FLOAT4.getType();
301+
}
302+
303+
if ("date".equalsIgnoreCase(dataType)) {
304+
columnType = Types.MinorType.DATEDAY.getType();
305+
}
306+
307+
if ("datetime".equalsIgnoreCase(dataType) || "datetime2".equalsIgnoreCase(dataType)
308+
|| "smalldatetime".equalsIgnoreCase(dataType) || "datetimeoffset".equalsIgnoreCase(dataType)) {
309+
columnType = Types.MinorType.DATEMILLI.getType();
310+
}
311+
312+
schemaBuilder.addField(FieldBuilder.newBuilder(columnName, columnType).build());
313+
}
314+
return schemaBuilder;
315+
}
316+
317+
private SchemaBuilder doDataTypeConversionForNonCompatible(Connection jdbcConnection, TableName tableName, HashMap<String, String> columnNameAndDataTypeMap) throws SQLException
318+
{
319+
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
320+
321+
try (ResultSet resultSet = getColumns(jdbcConnection.getCatalog(), tableName, jdbcConnection.getMetaData())) {
322+
boolean found = false;
246323
while (resultSet.next()) {
247324
Optional<ArrowType> columnType = JdbcArrowTypeConverter.toArrowType(
248325
resultSet.getInt("DATA_TYPE"),
249326
resultSet.getInt("COLUMN_SIZE"),
250327
resultSet.getInt("DECIMAL_DIGITS"),
251328
configOptions);
252-
columnName = resultSet.getString("COLUMN_NAME");
253-
254-
dataType = hashMap.get(columnName);
255-
LOGGER.debug("columnName: " + columnName);
256-
LOGGER.debug("dataType: " + dataType);
329+
String columnName = resultSet.getString("COLUMN_NAME");
330+
String dataType = columnNameAndDataTypeMap.get(columnName);
257331

258332
if (dataType != null && DataLakeGen2DataType.isSupported(dataType)) {
259333
columnType = Optional.of(DataLakeGen2DataType.fromType(dataType));
@@ -266,21 +340,19 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
266340
columnType = Optional.of(Types.MinorType.VARCHAR.getType());
267341
}
268342

269-
LOGGER.debug("columnType: " + columnType);
270343
if (columnType.isPresent() && SupportedTypes.isSupported(columnType.get())) {
271344
schemaBuilder.addField(FieldBuilder.newBuilder(columnName, columnType.get()).build());
272345
found = true;
273346
}
274347
else {
275-
LOGGER.error("getSchema: Unable to map type for column[" + columnName + "] to a supported type, attempted " + columnType);
348+
LOGGER.error("getSchema: Unable to map type for column[{}] to a supported type, attempted {}", columnName, columnType);
276349
}
277350
}
278351
if (!found) {
352+
LOGGER.error("Could not find any supported columns in table: {}.{}", tableName.getSchemaName(), tableName.getTableName());
279353
throw new RuntimeException("Could not find table in " + tableName.getSchemaName());
280354
}
281-
282-
partitionSchema.getFields().forEach(schemaBuilder::addField);
283-
return schemaBuilder.build();
284355
}
356+
return schemaBuilder;
285357
}
286358
}

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@
1818
* #L%
1919
*/
2020
package com.amazonaws.athena.connectors.datalakegen2;
21+
import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
22+
import com.amazonaws.athena.connector.lambda.data.Block;
23+
import com.amazonaws.athena.connector.lambda.data.BlockSpiller;
24+
import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter;
2125
import com.amazonaws.athena.connector.lambda.domain.Split;
2226
import com.amazonaws.athena.connector.lambda.domain.TableName;
2327
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
28+
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
2429
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2530
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
2631
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
@@ -29,20 +34,27 @@
2934
import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
3035
import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
3136
import com.google.common.annotations.VisibleForTesting;
37+
import org.apache.arrow.vector.types.pojo.ArrowType;
38+
import org.apache.arrow.vector.types.pojo.Field;
3239
import org.apache.arrow.vector.types.pojo.Schema;
3340
import org.apache.commons.lang3.Validate;
41+
import org.slf4j.Logger;
42+
import org.slf4j.LoggerFactory;
3443
import software.amazon.awssdk.services.athena.AthenaClient;
3544
import software.amazon.awssdk.services.s3.S3Client;
3645
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
3746

3847
import java.sql.Connection;
3948
import java.sql.PreparedStatement;
49+
import java.sql.ResultSet;
4050
import java.sql.SQLException;
51+
import java.util.Map;
4152

4253
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2Constants.QUOTE_CHARACTER;
4354

4455
public class DataLakeGen2RecordHandler extends JdbcRecordHandler
4556
{
57+
private static final Logger LOGGER = LoggerFactory.getLogger(DataLakeGen2RecordHandler.class);
4658
private static final int FETCH_SIZE = 1000;
4759
private final JdbcSplitQueryBuilder jdbcSplitQueryBuilder;
4860
public DataLakeGen2RecordHandler(java.util.Map<String, String> configOptions)
@@ -76,4 +88,56 @@ public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalog
7688
preparedStatement.setFetchSize(FETCH_SIZE);
7789
return preparedStatement;
7890
}
91+
92+
@Override
public void readWithConstraint(BlockSpiller blockSpiller, ReadRecordsRequest readRecordsRequest, QueryStatusChecker queryStatusChecker)
        throws Exception
{
    // Reads the rows for one split and spills them via the BlockSpiller.
    // Overridden (rather than using the base JdbcRecordHandler flow) so that
    // transaction handling can be skipped on Azure Synapse serverless SQL
    // pools, whose driver-side commit path uses the unsupported @@TRANCOUNT.
    LOGGER.info("{}: Catalog: {}, table {}, splits {}", readRecordsRequest.getQueryId(), readRecordsRequest.getCatalogName(), readRecordsRequest.getTableName(),
            readRecordsRequest.getSplit().getProperties());

    try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
        // Detect environment from the JDBC URL; returns SQL_POOL for serverless.
        String environment = DataLakeGen2Util.checkEnvironment(connection.getMetaData().getURL());
        if (!DataLakeGen2Constants.SQL_POOL.equalsIgnoreCase(environment)) {
            // autoCommit must be disabled to enable result streaming for some
            // database types. On Azure serverless, however, the resulting
            // transaction triggers @@TRANCOUNT errors during connection
            // cleanup, so it is left at the default there.
            connection.setAutoCommit(false);
        }
        try (PreparedStatement preparedStatement = buildSplitSql(connection, readRecordsRequest.getCatalogName(), readRecordsRequest.getTableName(),
                readRecordsRequest.getSchema(), readRecordsRequest.getConstraints(), readRecordsRequest.getSplit());
                ResultSet resultSet = preparedStatement.executeQuery()) {
            Map<String, String> partitionValues = readRecordsRequest.getSplit().getProperties();

            // Wire one writer per schema field: list-typed columns need a field
            // writer factory; everything else uses a plain extractor.
            GeneratedRowWriter.RowWriterBuilder rowWriterBuilder = GeneratedRowWriter.newBuilder(readRecordsRequest.getConstraints());
            for (Field next : readRecordsRequest.getSchema().getFields()) {
                if (next.getType() instanceof ArrowType.List) {
                    rowWriterBuilder.withFieldWriterFactory(next.getName(), makeFactory(next));
                }
                else {
                    rowWriterBuilder.withExtractor(next.getName(), makeExtractor(next, resultSet, partitionValues));
                }
            }

            GeneratedRowWriter rowWriter = rowWriterBuilder.build();
            int rowsReturnedFromDatabase = 0;
            while (resultSet.next()) {
                // Stop early (resources are released by try-with-resources) if
                // the query has been cancelled or failed elsewhere.
                if (!queryStatusChecker.isQueryRunning()) {
                    return;
                }
                // writeRows returns 1 when the row passed constraint filtering
                // and was written, 0 when it was filtered out.
                blockSpiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, resultSet) ? 1 : 0);
                rowsReturnedFromDatabase++;
            }
            LOGGER.info("{} rows returned by database.", rowsReturnedFromDatabase);

            /*
             * The SQL Server JDBC driver uses @@TRANCOUNT when performing commit(),
             * which on Azure serverless raises:
             *   com.microsoft.sqlserver.jdbc.SQLServerException: '@@TRANCOUNT' is not supported.
             * So connection.commit() is skipped for the Azure serverless environment
             * (where autoCommit was never disabled anyway).
             */
            if (!DataLakeGen2Constants.SQL_POOL.equalsIgnoreCase(environment)) {
                connection.commit();
            }
        }
    }
}
79143
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*-
2+
* #%L
3+
* athena-datalakegen2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.datalakegen2;
21+
22+
import org.apache.commons.lang3.StringUtils;
23+
import org.slf4j.Logger;
24+
import org.slf4j.LoggerFactory;
25+
26+
import java.util.regex.Matcher;
27+
import java.util.regex.Pattern;
28+
29+
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2Constants.SQL_POOL;
30+
31+
public class DataLakeGen2Util
32+
{
33+
private static final Logger LOGGER = LoggerFactory.getLogger(DataLakeGen2Util.class);
34+
35+
private DataLakeGen2Util()
36+
{
37+
}
38+
39+
private static final Pattern DATALAKE_CONN_STRING_PATTERN = Pattern.compile("([a-zA-Z]+)://([^;]+);(.*)");
40+
41+
public static String checkEnvironment(String url)
42+
{
43+
if (StringUtils.isBlank(url)) {
44+
return null;
45+
}
46+
47+
// checking whether it's Azure serverless environment or not based on host name
48+
Matcher m = DATALAKE_CONN_STRING_PATTERN.matcher(url);
49+
String hostName = "";
50+
if (m.find() && m.groupCount() == 3) {
51+
hostName = m.group(2);
52+
}
53+
54+
if (StringUtils.isNotBlank(hostName) && hostName.contains("ondemand")) {
55+
LOGGER.info("Azure serverless environment detected");
56+
return SQL_POOL;
57+
}
58+
59+
return null;
60+
}
61+
}

0 commit comments

Comments
 (0)