
Commit a3e1dda

committed
Snowflake connector performance improvement for S3 export
1 parent 8803ffd · commit a3e1dda


4 files changed: +117 -23 lines changed


athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java

Lines changed: 2 additions & 2 deletions
@@ -227,7 +227,7 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request
         String randomStr = UUID.randomUUID().toString();
         String queryID = request.getQueryId().replace("-", "").concat(randomStr);
         String catalog = request.getCatalogName();
-        String integrationName = catalog.concat(s3ExportBucket).concat("_integration").replaceAll("-", "_");
+        String integrationName = catalog.concat(s3ExportBucket).concat("_integration").replaceAll("-", "_").replaceAll(":", "");
         LOGGER.debug("Integration Name {}", integrationName);
         //Build the SQL query
         Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
@@ -359,7 +359,7 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
            }
        }
        catch (Exception throwables) {
-            throw new RuntimeException("Exception in execution copy statement {}", throwables);
+            throw new RuntimeException("Exception in execution export statement " + throwables.getMessage(), throwables);
        }
    }
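For context on the first change: the Snowflake storage-integration name is built by concatenating the catalog name, the configured export bucket, and an "_integration" suffix, then stripping characters that are not valid in an unquoted Snowflake identifier. A minimal sketch of that normalization with made-up values (variable names mirror the handler, the inputs are hypothetical; runnable as-is in jshell):

    // Hypothetical inputs; real values come from the request and the connector's configuration.
    String catalog = "my-catalog";
    String s3ExportBucket = "arn:aws:s3:::my-export-bucket"; // a sample value carrying ':' and '-'

    String integrationName = catalog
            .concat(s3ExportBucket)
            .concat("_integration")
            .replaceAll("-", "_")   // hyphens are not allowed in unquoted Snowflake identifiers
            .replaceAll(":", "");   // the newly added replacement also drops colons

    System.out.println(integrationName); // my_catalogarnawss3my_export_bucket_integration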

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilder.java

Lines changed: 95 additions & 11 deletions
@@ -37,13 +37,22 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.math.BigDecimal;
 import java.sql.Connection;
 import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.text.SimpleDateFormat;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Date;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import static org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID.Utf8;
+
 /**
  * Extends {@link JdbcSplitQueryBuilder} and implements MySql specific SQL clauses for split.
  *
@@ -52,9 +61,9 @@
 public class SnowflakeQueryStringBuilder
         extends JdbcSplitQueryBuilder
 {
-    private static final String EMPTY_STRING = "";
     private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeQueryStringBuilder.class);
-    private final String quoteCharacters = "\"";
+    private static final String quoteCharacters = "\"";
+    private static final String singleQuoteCharacters = "\'";
 
     public SnowflakeQueryStringBuilder(final String quoteCharacters, final FederationExpressionParser federationExpressionParser)
     {
@@ -136,6 +145,12 @@ protected String quote(String name)
         return quoteCharacters + name + quoteCharacters;
     }
 
+    protected String singleQuote(String name)
+    {
+        name = name.replace(singleQuoteCharacters, singleQuoteCharacters + singleQuoteCharacters);
+        return singleQuoteCharacters + name + singleQuoteCharacters;
+    }
+
     private List<String> toConjuncts(List<Field> columns, Constraints constraints, List<TypeAndValue> accumulator)
     {
         List<String> conjuncts = new ArrayList<>();
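Because filter values are now spliced directly into the generated export SQL instead of being bound as JDBC parameters, string literals must be single-quoted with embedded quotes doubled, which is what the new singleQuote helper does. A standalone sketch of the same escaping rule (hypothetical input; runnable as-is in jshell):

    String singleQuoteCharacters = "'";
    String name = "O'Brien"; // hypothetical string value appearing in a predicate

    String escaped = name.replace(singleQuoteCharacters, singleQuoteCharacters + singleQuoteCharacters);
    String literal = singleQuoteCharacters + escaped + singleQuoteCharacters;

    System.out.println(literal); // 'O''Brien', safe to embed in the COPY INTO ... WHERE clause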
@@ -181,10 +196,10 @@ private String toPredicate(String columnName, ValueSet valueSet, ArrowType type,
             if (!range.getLow().isLowerUnbounded()) {
                 switch (range.getLow().getBound()) {
                     case ABOVE:
-                        rangeConjuncts.add(toPredicate(columnName, ">", range.getLow().getValue(), type, accumulator));
+                        rangeConjuncts.add(toPredicate(columnName, ">", range.getLow().getValue(), type));
                         break;
                     case EXACTLY:
-                        rangeConjuncts.add(toPredicate(columnName, ">=", range.getLow().getValue(), type, accumulator));
+                        rangeConjuncts.add(toPredicate(columnName, ">=", range.getLow().getValue(), type));
                         break;
                     case BELOW:
                         throw new IllegalArgumentException("Low marker should never use BELOW bound");
@@ -197,10 +212,10 @@ private String toPredicate(String columnName, ValueSet valueSet, ArrowType type,
                     case ABOVE:
                         throw new IllegalArgumentException("High marker should never use ABOVE bound");
                     case EXACTLY:
-                        rangeConjuncts.add(toPredicate(columnName, "<=", range.getHigh().getValue(), type, accumulator));
+                        rangeConjuncts.add(toPredicate(columnName, "<=", range.getHigh().getValue(), type));
                         break;
                     case BELOW:
-                        rangeConjuncts.add(toPredicate(columnName, "<", range.getHigh().getValue(), type, accumulator));
+                        rangeConjuncts.add(toPredicate(columnName, "<", range.getHigh().getValue(), type));
                         break;
                     default:
                         throw new AssertionError("Unhandled bound: " + range.getHigh().getBound());
@@ -214,17 +229,86 @@ private String toPredicate(String columnName, ValueSet valueSet, ArrowType type,
 
             // Add back all of the possible single values either as an equality or an IN predicate
             if (singleValues.size() == 1) {
-                disjuncts.add(toPredicate(columnName, "=", Iterables.getOnlyElement(singleValues), type, accumulator));
+                disjuncts.add(toPredicate(columnName, "=", Iterables.getOnlyElement(singleValues), type));
             }
             else if (singleValues.size() > 1) {
+                List<Object> val = new ArrayList<>();
                 for (Object value : singleValues) {
-                    accumulator.add(new TypeAndValue(type, value));
+                    val.add(((type.getTypeID().equals(Utf8) || type.getTypeID().equals(ArrowType.ArrowTypeID.Date)) ? singleQuote(getObjectForWhereClause(columnName, value, type).toString()) : getObjectForWhereClause(columnName, value, type)));
                 }
-                String values = Joiner.on(",").join(Collections.nCopies(singleValues.size(), "?"));
-                disjuncts.add(quote(columnName) + " IN (" + values + ")");
+                String values = Joiner.on(",").join(val);
+                disjuncts.add(columnName + " IN (" + values + ")");
             }
         }
-
         return "(" + Joiner.on(" OR ").join(disjuncts) + ")";
     }
+
+    protected String toPredicate(String columnName, String operator, Object value, ArrowType type)
+    {
+        return columnName + " " + operator + " " + ((type.getTypeID().equals(Utf8) || type.getTypeID().equals(ArrowType.ArrowTypeID.Date)) ? singleQuote(getObjectForWhereClause(columnName, value, type).toString()) : getObjectForWhereClause(columnName, value, type));
+    }
+
+    private static Object getObjectForWhereClause(String columnName, Object value, ArrowType arrowType)
+    {
+        String val;
+        StringBuilder tempVal;
+
+        switch (arrowType.getTypeID()) {
+            case Int:
+                return ((Number) value).longValue();
+            case Decimal:
+                if (value instanceof BigDecimal) {
+                    return (BigDecimal) value;
+                }
+                else if (value instanceof Number) {
+                    return BigDecimal.valueOf(((Number) value).doubleValue());
+                }
+                else {
+                    throw new IllegalArgumentException("Unexpected type for decimal conversion: " + value.getClass().getName());
+                }
+            case FloatingPoint:
+                return (double) value;
+            case Bool:
+                return (Boolean) value;
+            case Utf8:
+                return value.toString();
+            case Date:
+                val = value.toString();
+                if (val.contains("-") && val.length() == 16) {
+                    LocalDateTime dateTime = LocalDateTime.parse(val);
+                    DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
+                    return dateTime.format(formatter);
+                }
+                else if (val.contains("-")) {
+                    tempVal = new StringBuilder(val);
+                    tempVal = tempVal.length() == 19 ? tempVal.append(".0") : tempVal;
+                    val = String.format("%-26s", tempVal).replace(' ', '0').replace("T", " ");
+                    return val; // Returning as string formatted datetime
+                }
+                else {
+                    long days = Long.parseLong(val);
+                    long milliseconds = TimeUnit.DAYS.toMillis(days);
+                    return new SimpleDateFormat("yyyy-MM-dd").format(new Date(milliseconds));
+                }
+            case Timestamp:
+                long millis = ((Number) value).longValue();
+                Timestamp timestamp = new Timestamp(millis);
+                SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
+                return sdf.format(timestamp);
+            case Time:
+            case Interval:
+            case Binary:
+            case FixedSizeBinary:
+            case Null:
+            case Struct:
+            case List:
+            case FixedSizeList:
+            case Union:
+            case NONE:
+                throw new UnsupportedOperationException("The Arrow type: " + arrowType.getTypeID().name() + " is currently not supported");
+            default:
+                throw new IllegalArgumentException("Unknown type encountered during processing: " + columnName +
+                        " Field Type: " + arrowType.getTypeID().name());
+        }
+    }
 }
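As a rough illustration of the conversion above: Arrow Date values commonly arrive as an epoch-day count, the Utf8/Date branches are wrapped in single quotes, and the results are joined into an inline IN list or comparison rather than a "?" placeholder. A standalone sketch with sample values (not the connector's actual call path; runnable as-is in jshell):

    import java.text.SimpleDateFormat;
    import java.util.Date;
    import java.util.concurrent.TimeUnit;

    // Epoch-day count as it might come out of an Arrow Date vector (hypothetical sample).
    long days = 19723;
    long milliseconds = TimeUnit.DAYS.toMillis(days);
    String dateLiteral = new SimpleDateFormat("yyyy-MM-dd").format(new Date(milliseconds));

    // Numeric values stay bare, Utf8/Date values get single quotes, and everything is inlined.
    String predicate = "year IN (" + String.join(",", "2023", "2024") + ") AND created = '" + dateLiteral + "'";
    System.out.println(predicate); // e.g. year IN (2023,2024) AND created = '2024-01-01' (exact date depends on the JVM default time zone)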

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java

Lines changed: 19 additions & 9 deletions
@@ -23,7 +23,6 @@
 import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
 import com.amazonaws.athena.connector.lambda.data.Block;
 import com.amazonaws.athena.connector.lambda.data.BlockSpiller;
-import com.amazonaws.athena.connector.lambda.data.BlockUtils;
 import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter;
 import com.amazonaws.athena.connector.lambda.data.writers.extractors.BigIntExtractor;
 import com.amazonaws.athena.connector.lambda.data.writers.extractors.BitExtractor;
@@ -80,8 +79,9 @@
 
 import java.io.IOException;
 import java.math.BigDecimal;
-import java.time.LocalDate;
 import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeParseException;
 import java.util.HashMap;
 
 import static com.amazonaws.athena.connectors.snowflake.SnowflakeConstants.SNOWFLAKE_NAME;
@@ -178,7 +178,7 @@ public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsR
                }
            }
            catch (Exception e) {
-                throw new RuntimeException("Error in connecting to S3 and selecting the object content for object : " + s3ObjectKey, e);
+                throw new RuntimeException("Error in object content for object : " + s3ObjectKey, e);
            }
        }
    }
@@ -284,20 +284,19 @@ private Extractor makeExtractor(Field field, HashMap<String, Types.MinorType> ma
                         dst.isSet = 0;
                     }
                     else {
+                        dst.value = (int) value;
                         dst.isSet = 1;
-                        dst.value = (int) LocalDate.parse(value.toString()).toEpochDay();
                     }
                 };
-
             case DATEMILLI:
                 return (DateMilliExtractor) (Object context, NullableDateMilliHolder dst) ->
                 {
-                    Object value = ((RowContext) context).getNameValue().get(fieldName).toString();
+                    Object value = ((RowContext) context).getNameValue().get(fieldName);
                     if (value == null) {
                         dst.isSet = 0;
                     }
                     else {
-                        dst.value = LocalDateTime.parse(value.toString()).atZone(BlockUtils.UTC_ZONE_ID).toInstant().toEpochMilli();
+                        dst.value = (long) value;
                         dst.isSet = 1;
                     }
                 };
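If I read this hunk right, the reworked DATEDAY and DATEMILLI extractors assume the values read back from the exported Parquet are already numeric epoch values, so the earlier string parsing is dropped. A hypothetical illustration of the values they now expect (sample numbers only; runnable as-is in jshell):

    // Hypothetical row values, not taken from the connector.
    Object dateDayValue = 19723;             // epoch-day count, what DATEDAY now casts with (int) value
    Object dateMilliValue = 1704067200000L;  // epoch milliseconds, what DATEMILLI now casts with (long) value

    int epochDay = (int) dateDayValue;
    long epochMilli = (long) dateMilliValue;

    System.out.println(java.time.LocalDate.ofEpochDay(epochDay));   // 2024-01-01
    System.out.println(java.time.Instant.ofEpochMilli(epochMilli)); // 2024-01-01T00:00:00Z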
@@ -309,14 +308,25 @@ private Extractor makeExtractor(Field field, HashMap<String, Types.MinorType> ma
                         dst.isSet = 0;
                     }
                     else {
-                        dst.value = value.toString();
+                        DateTimeFormatter inputFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm");
+                        DateTimeFormatter outputFormatter = DateTimeFormatter.ofPattern("HH:mm:ss");
+                        try {
+                            // Try parsing the input as a datetime string
+                            LocalDateTime dateTime = LocalDateTime.parse(value.toString(), inputFormatter);
+                            // If successful, return formatted time
+                            dst.value = dateTime.toLocalTime().format(outputFormatter);
+                        }
+                        catch (DateTimeParseException e) {
+                            // If parsing fails, return input as is
+                            dst.value = value.toString();
+                        }
                         dst.isSet = 1;
                     }
                 };
             case VARBINARY:
                 return (VarBinaryExtractor) (Object context, NullableVarBinaryHolder dst) ->
                 {
-                    Object value = ((RowContext) context).getNameValue().get(fieldName).toString();
+                    Object value = ((RowContext) context).getNameValue().get(fieldName);
                     if (value == null) {
                         dst.isSet = 0;
                     }
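The new time handling first tries to parse the raw value as a minute-precision datetime string and keeps only the time-of-day portion; anything that does not match is passed through unchanged. A standalone sketch of that parse-then-fallback branch (hypothetical inputs; runnable as-is in jshell):

    import java.time.LocalDateTime;
    import java.time.format.DateTimeFormatter;
    import java.time.format.DateTimeParseException;

    DateTimeFormatter inputFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm");
    DateTimeFormatter outputFormatter = DateTimeFormatter.ofPattern("HH:mm:ss");

    String[] samples = {"2024-01-01T09:30", "09:30:15"}; // hypothetical row values
    for (String raw : samples) {
        String result;
        try {
            // Datetime-shaped input: keep only the time-of-day portion.
            result = LocalDateTime.parse(raw, inputFormatter).toLocalTime().format(outputFormatter);
        }
        catch (DateTimeParseException e) {
            // Anything else falls back to the raw value, mirroring the extractor.
            result = raw;
        }
        System.out.println(result); // 09:30:00 on the first pass, 09:30:15 on the second
    }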

athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java

Lines changed: 1 addition & 1 deletion
@@ -268,7 +268,7 @@ public void getPartitions() throws Exception
         Block partitions = res.getPartitions();
 
         String actualQueryID = partitions.getFieldReader("queryId").readText().toString();
-        String expectedExportSql = "COPY INTO 's3://testS3Bucket/snowflake_data/" + actualQueryID + "/' FROM (SELECT \"day\", \"month\", \"year\", \"preparedStmt\", \"queryId\" FROM \"schema1\".\"table1\" WHERE ((\"day\" > ?)) AND ((\"month\" > ?)) AND ((\"year\" > ?))) STORAGE_INTEGRATION = defaulttestS3Bucket_integration HEADER = TRUE FILE_FORMAT = (TYPE = 'PARQUET', COMPRESSION = 'SNAPPY') MAX_FILE_SIZE = 16777216";
+        String expectedExportSql = "COPY INTO 's3://testS3Bucket/snowflake_data/" + actualQueryID + "/' FROM (SELECT \"day\", \"month\", \"year\", \"preparedStmt\", \"queryId\" FROM \"schema1\".\"table1\" WHERE ((day > 0)) AND ((month > 0)) AND ((year > 2000))) STORAGE_INTEGRATION = defaulttestS3Bucket_integration HEADER = TRUE FILE_FORMAT = (TYPE = 'PARQUET', COMPRESSION = 'SNAPPY') MAX_FILE_SIZE = 16777216";
 
         Assert.assertEquals(expectedExportSql, partitions.getFieldReader("preparedStmt").readText().toString());

0 commit comments
