Commit 0c51f57

Add serverless datalakegen2 support (#2973)
Co-authored-by: burhan94 <[email protected]>
1 parent a3eca13 commit 0c51f57

7 files changed: +598 -25 lines changed


athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2Constants.java

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ public class DataLakeGen2Constants
     public static final String DRIVER_CLASS = "com.microsoft.sqlserver.jdbc.SQLServerDriver";
     public static final int DEFAULT_PORT = 1433;
     public static final String QUOTE_CHARACTER = "\"";
+    public static final String SQL_POOL = "azureServerless";
 
     private DataLakeGen2Constants() {}
 }

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java

Lines changed: 136 additions & 21 deletions
@@ -221,45 +221,133 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
     {
         LOGGER.info("Inside getSchema");
 
-        String dataTypeQuery = "SELECT C.NAME AS COLUMN_NAME, TYPE_NAME(C.USER_TYPE_ID) AS DATA_TYPE " +
+        String dataTypeQuery = "SELECT C.NAME AS COLUMN_NAME, TYPE_NAME(C.USER_TYPE_ID) AS DATA_TYPE, " +
+                "C.PRECISION, C.SCALE " +
                 "FROM sys.columns C " +
                 "JOIN sys.types T " +
                 "ON C.USER_TYPE_ID=T.USER_TYPE_ID " +
                 "WHERE C.OBJECT_ID=OBJECT_ID(?)";
 
         String dataType;
         String columnName;
-        HashMap<String, String> hashMap = new HashMap<>();
-        boolean found = false;
+        int precision;
+        int scale;
+        HashMap<String, ColumnInfo> hashMap = new HashMap<>();
 
         SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
-        try (ResultSet resultSet = getColumns(jdbcConnection.getCatalog(), tableName, jdbcConnection.getMetaData());
-             Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
+        try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
              PreparedStatement stmt = connection.prepareStatement(dataTypeQuery)) {
             // fetch data types of columns and prepare map with column name and datatype.
             stmt.setString(1, tableName.getSchemaName() + "." + tableName.getTableName());
             try (ResultSet dataTypeResultSet = stmt.executeQuery()) {
                 while (dataTypeResultSet.next()) {
                     dataType = dataTypeResultSet.getString("DATA_TYPE");
                     columnName = dataTypeResultSet.getString("COLUMN_NAME");
-                    hashMap.put(columnName.trim(), dataType.trim());
+                    precision = dataTypeResultSet.getInt("PRECISION");
+                    scale = dataTypeResultSet.getInt("SCALE");
+                    hashMap.put(columnName.trim(), new ColumnInfo(dataType.trim(), precision, scale));
                 }
             }
+        }
+
+        String environment = DataLakeGen2Util.checkEnvironment(jdbcConnection.getMetaData().getURL());
+
+        if (DataLakeGen2Constants.SQL_POOL.equalsIgnoreCase(environment)) {
+            // The SQL Server driver's getColumns() method throws an exception in the Azure Serverless environment,
+            // so we do an explicit data type conversion instead.
+            schemaBuilder = doDataTypeConversion(hashMap);
+        }
+        else {
+            schemaBuilder = doDataTypeConversionForNonCompatible(jdbcConnection, tableName, hashMap);
+        }
+        // add partition columns
+        partitionSchema.getFields().forEach(schemaBuilder::addField);
+        return schemaBuilder.build();
+    }
 
+    private SchemaBuilder doDataTypeConversion(HashMap<String, ColumnInfo> columnNameAndDataTypeMap)
+    {
+        SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
+
+        for (Map.Entry<String, ColumnInfo> entry : columnNameAndDataTypeMap.entrySet()) {
+            String columnName = entry.getKey();
+            ColumnInfo columnInfo = entry.getValue();
+            String dataType = columnInfo.getDataType();
+            ArrowType columnType = Types.MinorType.VARCHAR.getType();
+
+            if ("char".equalsIgnoreCase(dataType) || "varchar".equalsIgnoreCase(dataType) ||
+                    "nchar".equalsIgnoreCase(dataType) || "nvarchar".equalsIgnoreCase(dataType)
+                    || "time".equalsIgnoreCase(dataType) || "uniqueidentifier".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.VARCHAR.getType();
+            }
+
+            if ("binary".equalsIgnoreCase(dataType) || "varbinary".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.VARBINARY.getType();
+            }
+
+            if ("bit".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.BIT.getType();
+            }
+
+            if ("tinyint".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.TINYINT.getType();
+            }
+
+            if ("smallint".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.SMALLINT.getType();
+            }
+
+            if ("int".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.INT.getType();
+            }
+
+            if ("bigint".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.BIGINT.getType();
+            }
+
+            if ("decimal".equalsIgnoreCase(dataType)) {
+                columnType = new ArrowType.Decimal(columnInfo.getPrecision(), columnInfo.getScale(), 128);
+            }
+
+            if ("numeric".equalsIgnoreCase(dataType) || "float".equalsIgnoreCase(dataType) || "smallmoney".equalsIgnoreCase(dataType) || "money".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.FLOAT8.getType();
+            }
+
+            if ("real".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.FLOAT4.getType();
+            }
+
+            if ("date".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.DATEDAY.getType();
+            }
+
+            if ("datetime".equalsIgnoreCase(dataType) || "datetime2".equalsIgnoreCase(dataType)
+                    || "smalldatetime".equalsIgnoreCase(dataType) || "datetimeoffset".equalsIgnoreCase(dataType)) {
+                columnType = Types.MinorType.DATEMILLI.getType();
+            }
+
+            schemaBuilder.addField(FieldBuilder.newBuilder(columnName, columnType).build());
+        }
+        return schemaBuilder;
+    }
+
+    private SchemaBuilder doDataTypeConversionForNonCompatible(Connection jdbcConnection, TableName tableName, HashMap<String, ColumnInfo> columnNameAndDataTypeMap) throws SQLException
+    {
+        SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
+
+        try (ResultSet resultSet = getColumns(jdbcConnection.getCatalog(), tableName, jdbcConnection.getMetaData())) {
+            boolean found = false;
             while (resultSet.next()) {
                 Optional<ArrowType> columnType = JdbcArrowTypeConverter.toArrowType(
                     resultSet.getInt("DATA_TYPE"),
                     resultSet.getInt("COLUMN_SIZE"),
                     resultSet.getInt("DECIMAL_DIGITS"),
                     configOptions);
-                columnName = resultSet.getString("COLUMN_NAME");
+                String columnName = resultSet.getString("COLUMN_NAME");
+                ColumnInfo columnInfo = columnNameAndDataTypeMap.get(columnName);
 
-                dataType = hashMap.get(columnName);
-                LOGGER.debug("columnName: " + columnName);
-                LOGGER.debug("dataType: " + dataType);
-
-                if (dataType != null && DataLakeGen2DataType.isSupported(dataType)) {
-                    columnType = Optional.of(DataLakeGen2DataType.fromType(dataType));
+                if (columnInfo != null && DataLakeGen2DataType.isSupported(columnInfo.getDataType())) {
+                    columnType = Optional.of(DataLakeGen2DataType.fromType(columnInfo.getDataType()));
                 }
 
                 /**
@@ -269,31 +357,58 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
                    columnType = Optional.of(Types.MinorType.VARCHAR.getType());
                }
 
-                LOGGER.debug("columnType: " + columnType);
                if (columnType.isPresent() && SupportedTypes.isSupported(columnType.get())) {
                    schemaBuilder.addField(FieldBuilder.newBuilder(columnName, columnType.get()).build());
                    found = true;
                }
                else {
-                    LOGGER.error("getSchema: Unable to map type for column[" + columnName + "] to a supported type, attempted " + columnType);
+                    LOGGER.error("getSchema: Unable to map type for column[{}] to a supported type, attempted {}", columnName, columnType);
                }
            }
            if (!found) {
+                LOGGER.error("Could not find any supported columns in table: {}.{}", tableName.getSchemaName(), tableName.getTableName());
                throw new RuntimeException("Could not find table in " + tableName.getSchemaName());
            }
-
-        partitionSchema.getFields().forEach(schemaBuilder::addField);
-        return schemaBuilder.build();
        }
+        return schemaBuilder;
    }
 
    @Override
    protected CredentialsProvider getCredentialProvider()
    {
        return CredentialsProviderFactory.createCredentialProvider(
-            getDatabaseConnectionConfig().getSecret(),
-            getCachableSecretsManager(),
-            new DataLakeGen2OAuthCredentialsProvider()
+                getDatabaseConnectionConfig().getSecret(),
+                getCachableSecretsManager(),
+                new DataLakeGen2OAuthCredentialsProvider()
        );
    }
 }
+
+class ColumnInfo
+{
+    private final String dataType;
+    private final int precision;
+    private final int scale;
+
+    public ColumnInfo(String dataType, int precision, int scale)
+    {
+        this.dataType = dataType;
+        this.precision = precision;
+        this.scale = scale;
+    }
+
+    public String getDataType()
+    {
+        return dataType;
+    }
+
+    public int getPrecision()
+    {
+        return precision;
+    }
+
+    public int getScale()
+    {
+        return scale;
+    }
+}
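
For illustration, here is a minimal sketch of what the new serverless path produces. The table columns below are hypothetical, and the class exists only to restate the mapping applied by the (private) doDataTypeConversion method; any type name that no if-block matches falls through to the VARCHAR default.

    package com.amazonaws.athena.connectors.datalakegen2;

    import java.util.HashMap;

    // Illustrative only: hypothetical columns, as the sys.columns metadata query
    // would populate the map, annotated with the Arrow type each would receive.
    class SchemaMappingExample
    {
        static HashMap<String, ColumnInfo> sampleColumns()
        {
            HashMap<String, ColumnInfo> columns = new HashMap<>();
            columns.put("id",         new ColumnInfo("bigint",    19, 0)); // -> Types.MinorType.BIGINT
            columns.put("price",      new ColumnInfo("decimal",   18, 4)); // -> new ArrowType.Decimal(18, 4, 128)
            columns.put("name",       new ColumnInfo("nvarchar",   0, 0)); // -> Types.MinorType.VARCHAR
            columns.put("created_at", new ColumnInfo("datetime2", 27, 7)); // -> Types.MinorType.DATEMILLI
            return columns;
        }
    }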

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java

Lines changed: 64 additions & 0 deletions
@@ -20,9 +20,14 @@
 package com.amazonaws.athena.connectors.datalakegen2;
 import com.amazonaws.athena.connector.credentials.CredentialsProvider;
 import com.amazonaws.athena.connector.credentials.CredentialsProviderFactory;
+import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
+import com.amazonaws.athena.connector.lambda.data.Block;
+import com.amazonaws.athena.connector.lambda.data.BlockSpiller;
+import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter;
 import com.amazonaws.athena.connector.lambda.domain.Split;
 import com.amazonaws.athena.connector.lambda.domain.TableName;
 import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
+import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
 import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
 import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
@@ -31,20 +36,27 @@
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
 import com.google.common.annotations.VisibleForTesting;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.commons.lang3.Validate;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.services.athena.AthenaClient;
 import software.amazon.awssdk.services.s3.S3Client;
 import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 
 import java.sql.Connection;
 import java.sql.PreparedStatement;
+import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.util.Map;
 
 import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2Constants.QUOTE_CHARACTER;
 
 public class DataLakeGen2RecordHandler extends JdbcRecordHandler
 {
+    private static final Logger LOGGER = LoggerFactory.getLogger(DataLakeGen2RecordHandler.class);
     private static final int FETCH_SIZE = 1000;
     private final JdbcSplitQueryBuilder jdbcSplitQueryBuilder;
     public DataLakeGen2RecordHandler(java.util.Map<String, String> configOptions)
@@ -88,4 +100,56 @@ protected CredentialsProvider getCredentialProvider()
            new DataLakeGen2OAuthCredentialsProvider()
        );
    }
+
+    @Override
+    public void readWithConstraint(BlockSpiller blockSpiller, ReadRecordsRequest readRecordsRequest, QueryStatusChecker queryStatusChecker)
+            throws Exception
+    {
+        LOGGER.info("{}: Catalog: {}, table {}, splits {}", readRecordsRequest.getQueryId(), readRecordsRequest.getCatalogName(), readRecordsRequest.getTableName(),
+                readRecordsRequest.getSplit().getProperties());
+
+        try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
+            String environment = DataLakeGen2Util.checkEnvironment(connection.getMetaData().getURL());
+            if (!DataLakeGen2Constants.SQL_POOL.equalsIgnoreCase(environment)) {
+                // This needs to be false to enable streaming for some database types,
+                // but in Azure serverless it causes @@TRANCOUNT errors during connection cleanup.
+                connection.setAutoCommit(false);
+            }
+            try (PreparedStatement preparedStatement = buildSplitSql(connection, readRecordsRequest.getCatalogName(), readRecordsRequest.getTableName(),
+                    readRecordsRequest.getSchema(), readRecordsRequest.getConstraints(), readRecordsRequest.getSplit());
+                    ResultSet resultSet = preparedStatement.executeQuery()) {
+                Map<String, String> partitionValues = readRecordsRequest.getSplit().getProperties();
+
+                GeneratedRowWriter.RowWriterBuilder rowWriterBuilder = GeneratedRowWriter.newBuilder(readRecordsRequest.getConstraints());
+                for (Field next : readRecordsRequest.getSchema().getFields()) {
+                    if (next.getType() instanceof ArrowType.List) {
+                        rowWriterBuilder.withFieldWriterFactory(next.getName(), makeFactory(next));
+                    }
+                    else {
+                        rowWriterBuilder.withExtractor(next.getName(), makeExtractor(next, resultSet, partitionValues));
+                    }
+                }
+
+                GeneratedRowWriter rowWriter = rowWriterBuilder.build();
+                int rowsReturnedFromDatabase = 0;
+                while (resultSet.next()) {
+                    if (!queryStatusChecker.isQueryRunning()) {
+                        return;
+                    }
+                    blockSpiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, resultSet) ? 1 : 0);
+                    rowsReturnedFromDatabase++;
+                }
+                LOGGER.info("{} rows returned by database.", rowsReturnedFromDatabase);
+
+                /*
+                 The SQL Server JDBC driver uses @@TRANCOUNT when performing commit(), which results in the RuntimeException below:
+                 com.microsoft.sqlserver.jdbc.SQLServerException: '@@TRANCOUNT' is not supported.
+                 So we skip connection.commit() in the Azure serverless environment.
+                */
+                if (!DataLakeGen2Constants.SQL_POOL.equalsIgnoreCase(environment)) {
+                    connection.commit();
+                }
+            }
+        }
+    }
 }
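
To make the guard concrete, here is a standalone sketch of the failure mode the commit avoids — the endpoint and credentials are hypothetical; the exception text is the one quoted in the comment above:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.SQLException;

    public class TrancountRepro
    {
        public static void main(String[] args) throws SQLException
        {
            // Hypothetical serverless endpoint; note the "-ondemand" host segment.
            String url = "jdbc:sqlserver://myws-ondemand.sql.azuresynapse.net:1433;"
                    + "databaseName=mydb;user=u;password=p";
            try (Connection connection = DriverManager.getConnection(url)) {
                connection.setAutoCommit(false); // begins a driver-managed transaction
                // ... execute a query and stream the results ...
                connection.commit(); // throws: '@@TRANCOUNT' is not supported.
            }
        }
    }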
athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2Util.java

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+/*-
+ * #%L
+ * athena-datalakegen2
+ * %%
+ * Copyright (C) 2019 - 2025 Amazon Web Services
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package com.amazonaws.athena.connectors.datalakegen2;
+
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2Constants.SQL_POOL;
+
+public class DataLakeGen2Util
+{
+    private static final Logger LOGGER = LoggerFactory.getLogger(DataLakeGen2Util.class);
+
+    private DataLakeGen2Util()
+    {
+    }
+
+    private static final Pattern DATALAKE_CONN_STRING_PATTERN = Pattern.compile("([a-zA-Z]+)://([^;]+);(.*)");
+
+    public static String checkEnvironment(String url)
+    {
+        if (StringUtils.isBlank(url)) {
+            return null;
+        }
+
+        // check whether this is an Azure serverless environment, based on the host name
+        Matcher m = DATALAKE_CONN_STRING_PATTERN.matcher(url);
+        String hostName = "";
+        if (m.find() && m.groupCount() == 3) {
+            hostName = m.group(2);
+        }
+
+        if (StringUtils.isNotBlank(hostName) && hostName.contains("ondemand")) {
+            LOGGER.info("Azure serverless environment detected");
+            return SQL_POOL;
+        }
+
+        return null;
+    }
+}
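
A quick sanity check of the host-name heuristic, using made-up workspace URLs — only the "ondemand" substring in the host portion (regex group 2, everything between "://" and the first ";") matters:

    import com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2Util;

    public class CheckEnvironmentDemo
    {
        public static void main(String[] args)
        {
            // Serverless pool host contains "ondemand" -> prints "azureServerless"
            System.out.println(DataLakeGen2Util.checkEnvironment(
                    "jdbc:sqlserver://myws-ondemand.sql.azuresynapse.net:1433;databaseName=db"));

            // Dedicated pool host has no "ondemand" marker -> prints "null"
            System.out.println(DataLakeGen2Util.checkEnvironment(
                    "jdbc:sqlserver://myws.sql.azuresynapse.net:1433;databaseName=db"));
        }
    }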
