
Commit a332ef5

VenkatasivareddyTR authored and Trianz-Akshay committed
Infer schema from gremlin/sparql query results for Neptune qpt (awslabs#2543)
Co-authored-by: akshay.kachore <[email protected]>
1 parent 865c596 commit a332ef5

File tree

11 files changed: +973 −122 lines changed

athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/Constants.java

Lines changed: 3 additions & 0 deletions

@@ -53,4 +53,7 @@ protected Constants()
     public static final int PREFIX_LEN = PREFIX_KEY.length();
 
     public static final String GREMLIN_QUERY_SUPPORT_TYPE = "valueMap";
+    public static final String RDF_COMPONENT_TYPE = "rdf";
+    public static final String GREMLIN_QUERY_LIMIT = "limit(10)";
+    public static final String SPARQL_QUERY_LIMIT = "limit 10";
 }
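
For context, a minimal sketch of how these new limit constants get applied during schema inference. The sample queries are illustrative; only the concat logic mirrors NeptuneMetadataHandler.doGetQueryPassthroughSchema below:

// Illustrative queries; the limit-appending logic is taken from the commit.
String gremlinQuery = "g.V().hasLabel('airport').valueMap()";
gremlinQuery = gremlinQuery.concat("." + Constants.GREMLIN_QUERY_LIMIT);
// -> g.V().hasLabel('airport').valueMap().limit(10)

String sparqlQuery = "select ?s ?p ?o where { ?s ?p ?o }";
sparqlQuery = sparqlQuery.contains("limit")
        ? sparqlQuery
        : sparqlQuery.concat("\n" + Constants.SPARQL_QUERY_LIMIT);
// -> the original query with "limit 10" appended on a new line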

athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java

Lines changed: 43 additions & 63 deletions

@@ -22,7 +22,6 @@
 import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
 import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
 import com.amazonaws.athena.connector.lambda.data.BlockWriter;
-import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
 import com.amazonaws.athena.connector.lambda.domain.Split;
 import com.amazonaws.athena.connector.lambda.domain.TableName;
 import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
@@ -42,7 +41,8 @@
 import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
 import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
 import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler;
-import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneGremlinQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneSparqlQueryPassthrough;
 import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.util.VisibleForTesting;
@@ -67,6 +67,9 @@
 import java.util.NoSuchElementException;
 import java.util.Set;
 
+import static com.amazonaws.athena.connectors.neptune.Constants.GREMLIN_QUERY_LIMIT;
+import static com.amazonaws.athena.connectors.neptune.Constants.RDF_COMPONENT_TYPE;
+import static com.amazonaws.athena.connectors.neptune.Constants.SPARQL_QUERY_LIMIT;
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -94,7 +97,8 @@ public class NeptuneMetadataHandler extends GlueMetadataHandler
     private final String glueDBName;
 
     private NeptuneConnection neptuneConnection = null;
-    private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();
+    private final NeptuneGremlinQueryPassthrough gremlinQueryPassthrough = new NeptuneGremlinQueryPassthrough();
+    private final NeptuneSparqlQueryPassthrough sparqlQueryPassthrough = new NeptuneSparqlQueryPassthrough();
 
     public NeptuneMetadataHandler(java.util.Map<String, String> configOptions)
     {
@@ -127,16 +131,17 @@ protected NeptuneMetadataHandler(
     public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
     {
         ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
-        queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
+        gremlinQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
+        sparqlQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
 
         return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
     }
-
+
     /**
      * Since the entire Neptune cluster is considered as a single graph database,
      * just return the glue database name provided as a single database (schema)
      * name.
-     *
+     *
      * @param allocator Tool for creating and managing Apache Arrow Blocks.
      * @param request   Provides details on who made the request and which Athena
      *                  catalog they are querying.
@@ -158,7 +163,7 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator allocator, ListSchem
     /**
      * Used to get the list of tables that this data source contains. In this case,
      * fetch list of tables in the Glue database provided.
-     *
+     *
     * @param allocator Tool for creating and managing Apache Arrow Blocks.
     * @param request   Provides details on who made the request and which Athena
     *                  catalog and database they are querying.
@@ -191,7 +196,7 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesReque
     /**
      * Used to get definition (field names, types, descriptions, etc...) of a Table.
      *
-     * @param allocator Tool for creating and managing Apache Arrow Blocks.
+     * @param blockAllocator Tool for creating and managing Apache Arrow Blocks.
      * @param request   Provides details on who made the request and which Athena
     *                   catalog, database, and table they are querying.
     * @return A GetTableResponse which primarily contains: 1. An Apache Arrow
@@ -209,10 +214,10 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
         Schema tableSchema = null;
         try {
             if (glue != null) {
-                tableSchema = super.doGetTable(blockAllocator, request).getSchema();
+                tableSchema = super.doGetTable(blockAllocator, request).getSchema();
                 logger.info("doGetTable: Retrieved schema for table[{}] from AWS Glue.", request.getTableName());
             }
-        }
+        }
         catch (RuntimeException ex) {
             logger.warn("doGetTable: Unable to retrieve table[{}:{}] from AWS Glue.",
                     request.getTableName().getSchemaName(), request.getTableName().getTableName(), ex);
@@ -226,7 +231,7 @@
      */
     @Override
     public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request,
-            QueryStatusChecker queryStatusChecker) throws Exception
+            QueryStatusChecker queryStatusChecker) throws Exception
     {
         // No implemenation as connector doesn't support partitioning
     }
@@ -255,7 +260,7 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request
     * the RecordHandler has easy access to it.
     */
    @Override
-    public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsRequest request)
+    public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsRequest request)
    {
        // Every split must have a unique location if we wish to spill to avoid failures
        SpillLocation spillLocation = makeSpillLocation(request);
@@ -266,7 +271,7 @@
    }

    @Override
-    protected Field convertField(String name, String glueType)
+    protected Field convertField(String name, String glueType)
    {
        return GlueFieldLexer.lex(name, glueType);
    }
@@ -275,15 +280,17 @@ protected Field convertField(String name, String glueType)
     public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
     {
         Map<String, String> qptArguments = request.getQueryPassthroughArguments();
-        queryPassthrough.verify(qptArguments);
-        String schemaName = qptArguments.get(NeptuneQueryPassthrough.DATABASE);
-        String tableName = qptArguments.get(NeptuneQueryPassthrough.COLLECTION);
-        TableName tableNameObj = new TableName(schemaName, tableName);
-        request = new GetTableRequest(request.getIdentity(), request.getQueryId(),
-                request.getCatalogName(), tableNameObj, request.getQueryPassthroughArguments());
+        if (qptArguments.containsKey(NeptuneGremlinQueryPassthrough.TRAVERSE)) {
+            gremlinQueryPassthrough.verify(qptArguments);
+        }
+        else {
+            sparqlQueryPassthrough.verify(qptArguments);
+        }
 
-        GetTableResponse getTableResponse = doGetTable(allocator, request);
-        List<Field> fields = getTableResponse.getSchema().getFields();
+        String schemaName;
+        String tableName;
+        String componentTypeValue;
+        TableName tableNameObj;
         Schema schema;
         Enums.GraphType graphType = Enums.GraphType.PROPERTYGRAPH;
 
@@ -293,19 +300,23 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
 
         switch (graphType){
             case PROPERTYGRAPH:
+                schemaName = qptArguments.get(NeptuneGremlinQueryPassthrough.DATABASE);
+                tableName = qptArguments.get(NeptuneGremlinQueryPassthrough.COLLECTION);
+                componentTypeValue = qptArguments.get(NeptuneGremlinQueryPassthrough.COMPONENT_TYPE);
+                tableNameObj = new TableName(schemaName, tableName);
                 Client client = neptuneConnection.getNeptuneClientConnection();
                 GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
-                String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
-                gremlinQuery = gremlinQuery.concat(".limit(1)");
+                String gremlinQuery = qptArguments.get(NeptuneGremlinQueryPassthrough.TRAVERSE);
+                gremlinQuery = gremlinQuery.concat("." + GREMLIN_QUERY_LIMIT);
                 logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
                 Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
                 GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
                 if (graphTraversalForSchema.hasNext()) {
                     Object responseObj = graphTraversalForSchema.next();
-                    if (responseObj instanceof Map && gremlinQuery.contains(Constants.GREMLIN_QUERY_SUPPORT_TYPE)) {
+                    if (responseObj instanceof Map) {
                         logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with valueMap");
                         Map graphTraversalObj = (Map) responseObj;
-                        schema = getSchemaFromResults(getTableResponse, graphTraversalObj, fields);
+                        schema = NeptuneSchemaUtils.getSchemaFromResults(graphTraversalObj, componentTypeValue, tableName);
                         return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
                     }
                     else {
@@ -319,16 +330,17 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
                 }
 
             case RDF:
-                String sparqlQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
-                sparqlQuery = sparqlQuery.contains("limit") ? sparqlQuery : sparqlQuery.concat("\nlimit 1");
+                schemaName = qptArguments.get(NeptuneSparqlQueryPassthrough.DATABASE);
+                tableName = qptArguments.get(NeptuneSparqlQueryPassthrough.COLLECTION);
+                tableNameObj = new TableName(schemaName, tableName);
+                String sparqlQuery = qptArguments.get(NeptuneSparqlQueryPassthrough.QUERY);
+                sparqlQuery = sparqlQuery.contains("limit") ? sparqlQuery : sparqlQuery.concat("\n" + SPARQL_QUERY_LIMIT);
                 logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema sparql query with limit: " + sparqlQuery);
                 NeptuneSparqlConnection neptuneSparqlConnection = (NeptuneSparqlConnection) neptuneConnection;
                 neptuneSparqlConnection.runQuery(sparqlQuery);
-                String strim = getTableResponse.getSchema().getCustomMetadata().get(Constants.SCHEMA_STRIP_URI);
-                boolean trimURI = strim == null ? false : Boolean.parseBoolean(strim);
                 if (neptuneSparqlConnection.hasNext()) {
-                    Map<String, Object> resultsMap = neptuneSparqlConnection.next(trimURI);
-                    schema = getSchemaFromResults(getTableResponse, resultsMap, fields);
+                    Map<String, Object> resultsMap = neptuneSparqlConnection.next();
+                    schema = NeptuneSchemaUtils.getSchemaFromResults(resultsMap, RDF_COMPONENT_TYPE, tableName);
                     return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
                 }
                 else {
@@ -339,36 +351,4 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
                 throw new IllegalArgumentException("Unsupported graphType: " + graphType);
         }
     }
-
-    private Schema getSchemaFromResults(GetTableResponse getTableResponse, Map resultsMap, List<Field> fields)
-    {
-        Schema schema;
-        //In case of 'gremlin/sparql query' is fetching all columns then we can use same schema from glue
-        //otherwise we will build schema from gremlin/sparql query result column names.
-        if (resultsMap != null && resultsMap.size() == fields.size()) {
-            schema = getTableResponse.getSchema();
-        }
-        else {
-            SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
-            //Building schema from gremlin/sparql query results and list of fields from glue response.
-            //It's require only when we are selecting limited columns.
-            resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
-            Map<String, String> metaData = getTableResponse.getSchema().getCustomMetadata();
-            for (Map.Entry<String, String> map : metaData.entrySet()) {
-                schemaBuilder.addMetadata(map.getKey(), map.getValue());
-            }
-            schema = schemaBuilder.build();
-        }
-        return schema;
-    }
-
-    private void buildSchema(String columnName, List<Field> fields, SchemaBuilder schemaBuilder)
-    {
-        for (Field field : fields) {
-            if (field.getName().equals(columnName)) {
-                schemaBuilder.addField(field);
-                break;
-            }
-        }
-    }
 }
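
To make the new dual-passthrough dispatch concrete, here is a hedged sketch of the argument maps each engine might send. All values below are illustrative; only the key names and the branch condition come from the commit. A Gremlin request is recognized by the presence of the TRAVERSE argument; anything else is verified as SPARQL.

// Hypothetical Gremlin passthrough arguments (values are made up):
Map<String, String> gremlinArgs = Map.of(
        NeptuneGremlinQueryPassthrough.DATABASE, "graph-database",
        NeptuneGremlinQueryPassthrough.COLLECTION, "airport",
        NeptuneGremlinQueryPassthrough.COMPONENT_TYPE, "vertex",
        NeptuneGremlinQueryPassthrough.TRAVERSE, "g.V().hasLabel('airport').valueMap()");

// Hypothetical SPARQL passthrough arguments:
Map<String, String> sparqlArgs = Map.of(
        NeptuneSparqlQueryPassthrough.DATABASE, "graph-database",
        NeptuneSparqlQueryPassthrough.COLLECTION, "airport_rdf",
        NeptuneSparqlQueryPassthrough.QUERY, "select ?s ?p ?o where { ?s ?p ?o }");

// Same branch as doGetQueryPassthroughSchema: TRAVERSE present -> Gremlin verifier,
// otherwise the SPARQL verifier runs.
if (gremlinArgs.containsKey(NeptuneGremlinQueryPassthrough.TRAVERSE)) {
    gremlinQueryPassthrough.verify(gremlinArgs);
}
else {
    sparqlQueryPassthrough.verify(sparqlArgs);
}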
athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneSchemaUtils.java

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+/*-
+ * #%L
+ * athena-neptune
+ * %%
+ * Copyright (C) 2019 - 2020 Amazon Web Services
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package com.amazonaws.athena.connectors.neptune;
+
+import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.math.BigInteger;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Collection of helpful utilities that handle Neptune schema inference, type, and naming conversion.
+ */
+public class NeptuneSchemaUtils
+{
+    private static final Logger logger = LoggerFactory.getLogger(NeptuneSchemaUtils.class);
+
+    private NeptuneSchemaUtils() {}
+
+    public static Schema getSchemaFromResults(Map resultsMap, String componentTypeValue, String tableName)
+    {
+        Schema schema;
+        SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
+        //Building schema from gremlin/sparql query results.
+        resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), columnValue, schemaBuilder));
+        schemaBuilder.addMetadata(Constants.SCHEMA_COMPONENT_TYPE, componentTypeValue);
+        schemaBuilder.addMetadata(Constants.SCHEMA_GLABEL, tableName);
+        schema = schemaBuilder.build();
+        return schema;
+    }
+
+    private static void buildSchema(String columnName, Object columnValue, SchemaBuilder schemaBuilder)
+    {
+        schemaBuilder.addField(getArrowFieldForNeptune(columnName, columnValue));
+    }
+
+    /**
+     * Infers the type of a field from Neptune data.
+     *
+     * @param key The key of the field we are attempting to infer.
+     * @param value A value from the key whose type we are attempting to infer.
+     * @return The Apache Arrow field definition of the inferred key/value.
+     */
+    private static Field getArrowFieldForNeptune(String key, Object value)
+    {
+        if (value instanceof String || value instanceof java.util.UUID) {
+            return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
+        }
+        else if (value instanceof Integer) {
+            return new Field(key, FieldType.nullable(Types.MinorType.INT.getType()), null);
+        }
+        else if (value instanceof BigInteger) {
+            return new Field(key, FieldType.nullable(Types.MinorType.BIGINT.getType()), null);
+        }
+        else if (value instanceof Long) {
+            return new Field(key, FieldType.nullable(Types.MinorType.BIGINT.getType()), null);
+        }
+        else if (value instanceof Boolean) {
+            return new Field(key, FieldType.nullable(Types.MinorType.BIT.getType()), null);
+        }
+        else if (value instanceof Float) {
+            return new Field(key, FieldType.nullable(Types.MinorType.FLOAT4.getType()), null);
+        }
+        else if (value instanceof Double) {
+            return new Field(key, FieldType.nullable(Types.MinorType.FLOAT8.getType()), null);
+        }
+        else if (value instanceof java.util.Date) {
+            return new Field(key, FieldType.nullable(Types.MinorType.DATEMILLI.getType()), null);
+        }
+        else if (value instanceof List) {
+            return getArrowFieldForNeptune(key, ((List<?>) value).get(0));
+        }
+
+        String className = (value == null || value.getClass() == null) ? "null" : value.getClass().getName();
+        logger.warn("Unknown type[{}] for field[{}], defaulting to varchar.", className, key);
+        return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
+    }
+}
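
A quick way to sanity-check the new inference path. The wrapper class and sample values below are illustrative; only getSchemaFromResults and its metadata keys come from the commit:

import com.amazonaws.athena.connectors.neptune.NeptuneSchemaUtils;
import org.apache.arrow.vector.types.pojo.Schema;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SchemaInferenceDemo
{
    public static void main(String[] args)
    {
        // Shaped like one row of a Gremlin valueMap() result, where values arrive
        // as single-element lists; getArrowFieldForNeptune infers from the first element.
        Map<String, Object> row = new LinkedHashMap<>();
        row.put("code", List.of("SEA"));   // String  -> VARCHAR
        row.put("runways", List.of(3));    // Integer -> INT
        row.put("elev", List.of(131L));    // Long    -> BIGINT
        row.put("intl", List.of(true));    // Boolean -> BIT

        // "vertex" is an assumed component type; the RDF path passes Constants.RDF_COMPONENT_TYPE.
        Schema schema = NeptuneSchemaUtils.getSchemaFromResults(row, "vertex", "airport");
        // Expect four fields plus metadata keyed by Constants.SCHEMA_COMPONENT_TYPE and Constants.SCHEMA_GLABEL.
        System.out.println(schema);
    }
}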

athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/PropertyGraphHandler.java

Lines changed: 4 additions & 4 deletions

@@ -30,7 +30,7 @@
 import com.amazonaws.athena.connectors.neptune.propertygraph.rowwriters.CustomSchemaRowWriter;
 import com.amazonaws.athena.connectors.neptune.propertygraph.rowwriters.EdgeRowWriter;
 import com.amazonaws.athena.connectors.neptune.propertygraph.rowwriters.VertexRowWriter;
-import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneGremlinQueryPassthrough;
 import org.apache.arrow.util.VisibleForTesting;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.tinkerpop.gremlin.driver.Client;
@@ -73,7 +73,7 @@ public class PropertyGraphHandler
     */
 
    private final NeptuneConnection neptuneConnection;
-    private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();
+    private final NeptuneGremlinQueryPassthrough queryPassthrough = new NeptuneGremlinQueryPassthrough();
 
    @VisibleForTesting
    public PropertyGraphHandler(NeptuneConnection neptuneConnection)
@@ -117,8 +117,8 @@ public void executeQuery(
         if (recordsRequest.getConstraints().isQueryPassThrough()) {
             Map<String, String> qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments();
             queryPassthrough.verify(qptArguments);
-            labelName = qptArguments.get(NeptuneQueryPassthrough.COLLECTION);
-            gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
+            labelName = qptArguments.get(NeptuneGremlinQueryPassthrough.COLLECTION);
+            gremlinQuery = qptArguments.get(NeptuneGremlinQueryPassthrough.TRAVERSE);
 
             Object object = getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
             graphTraversal = (GraphTraversal) object;