@@ -22,7 +22,6 @@
 import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
 import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
 import com.amazonaws.athena.connector.lambda.data.BlockWriter;
-import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
 import com.amazonaws.athena.connector.lambda.domain.Split;
 import com.amazonaws.athena.connector.lambda.domain.TableName;
 import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
@@ -42,7 +41,8 @@
 import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
 import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
 import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler;
-import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneGremlinQueryPassthrough;
+import com.amazonaws.athena.connectors.neptune.qpt.NeptuneSparqlQueryPassthrough;
 import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.util.VisibleForTesting;
@@ -67,6 +67,9 @@
 import java.util.NoSuchElementException;
 import java.util.Set;
 
+import static com.amazonaws.athena.connectors.neptune.Constants.GREMLIN_QUERY_LIMIT;
+import static com.amazonaws.athena.connectors.neptune.Constants.RDF_COMPONENT_TYPE;
+import static com.amazonaws.athena.connectors.neptune.Constants.SPARQL_QUERY_LIMIT;
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -94,7 +97,8 @@ public class NeptuneMetadataHandler extends GlueMetadataHandler
     private final String glueDBName;
 
     private NeptuneConnection neptuneConnection = null;
-    private final NeptuneQueryPassthrough queryPassthrough = new NeptuneQueryPassthrough();
+    private final NeptuneGremlinQueryPassthrough gremlinQueryPassthrough = new NeptuneGremlinQueryPassthrough();
+    private final NeptuneSparqlQueryPassthrough sparqlQueryPassthrough = new NeptuneSparqlQueryPassthrough();
 
     public NeptuneMetadataHandler(java.util.Map<String, String> configOptions)
     {
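
Note on the hunk above: the single NeptuneQueryPassthrough signature is split into Gremlin and SPARQL variants so that each query language can declare its own argument list. The new classes are not part of this diff; below is a minimal sketch of what the Gremlin variant could look like, assuming the federation SDK's QueryPassthroughSignature base class, and inferring only the argument names (DATABASE, COLLECTION, COMPONENT_TYPE, TRAVERSE) from their uses later in this file. The function schema and name are assumptions.

```java
import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;

// Sketch only: the real NeptuneGremlinQueryPassthrough lives in the connector's
// qpt package and may differ. The base class supplies the verify() and
// addQueryPassthroughCapabilityIfEnabled() methods the handler calls.
public class NeptuneGremlinQueryPassthrough extends QueryPassthroughSignature
{
    private static final Logger LOGGER = LoggerFactory.getLogger(NeptuneGremlinQueryPassthrough.class);

    // Function namespace and name are assumptions; only the argument names
    // below are visible in this diff.
    public static final String SCHEMA_NAME = "system";
    public static final String NAME = "query";
    public static final String DATABASE = "DATABASE";
    public static final String COLLECTION = "COLLECTION";
    public static final String COMPONENT_TYPE = "COMPONENT_TYPE";
    public static final String TRAVERSE = "TRAVERSE";

    @Override
    public String getFunctionSchema()
    {
        return SCHEMA_NAME;
    }

    @Override
    public String getFunctionName()
    {
        return NAME;
    }

    @Override
    public List<String> getFunctionArguments()
    {
        // The inherited verify() checks that an incoming request supplies
        // exactly these arguments.
        return Arrays.asList(DATABASE, COLLECTION, COMPONENT_TYPE, TRAVERSE);
    }

    @Override
    public Logger getLogger()
    {
        return LOGGER;
    }
}
```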
@@ -127,16 +131,17 @@ protected NeptuneMetadataHandler(
     public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
     {
         ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
-        queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
+        gremlinQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
+        sparqlQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
 
         return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
     }
-
+
     /**
      * Since the entire Neptune cluster is considered as a single graph database,
      * just return the glue database name provided as a single database (schema)
      * name.
-     *
+     *
      * @param allocator Tool for creating and managing Apache Arrow Blocks.
      * @param request Provides details on who made the request and which Athena
      *        catalog they are querying.
@@ -158,7 +163,7 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator allocator, ListSchem
     /**
      * Used to get the list of tables that this data source contains. In this case,
      * fetch list of tables in the Glue database provided.
-     *
+     *
      * @param allocator Tool for creating and managing Apache Arrow Blocks.
      * @param request Provides details on who made the request and which Athena
      *        catalog and database they are querying.
@@ -191,7 +196,7 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesReque
     /**
      * Used to get definition (field names, types, descriptions, etc...) of a Table.
      *
-     * @param allocator Tool for creating and managing Apache Arrow Blocks.
+     * @param blockAllocator Tool for creating and managing Apache Arrow Blocks.
      * @param request Provides details on who made the request and which Athena
      *        catalog, database, and table they are querying.
      * @return A GetTableResponse which primarily contains: 1. An Apache Arrow
@@ -209,10 +214,10 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
         Schema tableSchema = null;
         try {
             if (glue != null) {
-                tableSchema = super.doGetTable(blockAllocator, request).getSchema();
+                tableSchema = super.doGetTable(blockAllocator, request).getSchema();
                 logger.info("doGetTable: Retrieved schema for table[{}] from AWS Glue.", request.getTableName());
             }
-        }
+        }
         catch (RuntimeException ex) {
             logger.warn("doGetTable: Unable to retrieve table[{}:{}] from AWS Glue.",
                 request.getTableName().getSchemaName(), request.getTableName().getTableName(), ex);
@@ -226,7 +231,7 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
      */
     @Override
     public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request,
-        QueryStatusChecker queryStatusChecker) throws Exception
+        QueryStatusChecker queryStatusChecker) throws Exception
     {
         // No implementation as connector doesn't support partitioning
     }
@@ -255,7 +260,7 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request
      * the RecordHandler has easy access to it.
      */
     @Override
-    public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsRequest request)
+    public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsRequest request)
     {
         // Every split must have a unique location if we wish to spill to avoid failures
         SpillLocation spillLocation = makeSpillLocation(request);
@@ -266,7 +271,7 @@ public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsReq
     }
 
     @Override
-    protected Field convertField(String name, String glueType)
+    protected Field convertField(String name, String glueType)
     {
         return GlueFieldLexer.lex(name, glueType);
     }
@@ -275,15 +280,17 @@ protected Field convertField(String name, String glueType)
     public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
     {
         Map<String, String> qptArguments = request.getQueryPassthroughArguments();
-        queryPassthrough.verify(qptArguments);
-        String schemaName = qptArguments.get(NeptuneQueryPassthrough.DATABASE);
-        String tableName = qptArguments.get(NeptuneQueryPassthrough.COLLECTION);
-        TableName tableNameObj = new TableName(schemaName, tableName);
-        request = new GetTableRequest(request.getIdentity(), request.getQueryId(),
-                request.getCatalogName(), tableNameObj, request.getQueryPassthroughArguments());
+        if (qptArguments.containsKey(NeptuneGremlinQueryPassthrough.TRAVERSE)) {
+            gremlinQueryPassthrough.verify(qptArguments);
+        }
+        else {
+            sparqlQueryPassthrough.verify(qptArguments);
+        }
 
-        GetTableResponse getTableResponse = doGetTable(allocator, request);
-        List<Field> fields = getTableResponse.getSchema().getFields();
+        String schemaName;
+        String tableName;
+        String componentTypeValue;
+        TableName tableNameObj;
         Schema schema;
         Enums.GraphType graphType = Enums.GraphType.PROPERTYGRAPH;
 
@@ -293,19 +300,23 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
 
         switch (graphType) {
             case PROPERTYGRAPH:
+                schemaName = qptArguments.get(NeptuneGremlinQueryPassthrough.DATABASE);
+                tableName = qptArguments.get(NeptuneGremlinQueryPassthrough.COLLECTION);
+                componentTypeValue = qptArguments.get(NeptuneGremlinQueryPassthrough.COMPONENT_TYPE);
+                tableNameObj = new TableName(schemaName, tableName);
                 Client client = neptuneConnection.getNeptuneClientConnection();
                 GraphTraversalSource graphTraversalSource = neptuneConnection.getTraversalSource(client);
-                String gremlinQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
-                gremlinQuery = gremlinQuery.concat(".limit(1)");
+                String gremlinQuery = qptArguments.get(NeptuneGremlinQueryPassthrough.TRAVERSE);
+                gremlinQuery = gremlinQuery.concat("." + GREMLIN_QUERY_LIMIT);
                 logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with limit: " + gremlinQuery);
                 Object object = new PropertyGraphHandler(neptuneConnection).getResponseFromGremlinQuery(graphTraversalSource, gremlinQuery);
                 GraphTraversal graphTraversalForSchema = (GraphTraversal) object;
                 if (graphTraversalForSchema.hasNext()) {
                     Object responseObj = graphTraversalForSchema.next();
-                    if (responseObj instanceof Map && gremlinQuery.contains(Constants.GREMLIN_QUERY_SUPPORT_TYPE)) {
+                    if (responseObj instanceof Map) {
                         logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with valueMap");
                         Map graphTraversalObj = (Map) responseObj;
-                        schema = getSchemaFromResults(getTableResponse, graphTraversalObj, fields);
+                        schema = NeptuneSchemaUtils.getSchemaFromResults(graphTraversalObj, componentTypeValue, tableName);
                         return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
                     }
                     else {
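
Note on the hunks above: the handler now branches on the presence of the TRAVERSE argument to decide which signature verifies the request and which key holds the query text. A small, self-contained illustration of that dispatch rule; the literal keys stand in for the passthrough classes' constants, and all values are invented samples:

```java
import java.util.Map;

// Demo of the dispatch rule in doGetQueryPassthroughSchema: the presence of
// TRAVERSE marks a Gremlin request, everything else is treated as SPARQL.
public final class QptDispatchDemo
{
    public static void main(String[] args)
    {
        Map<String, String> gremlinArgs = Map.of(
                "DATABASE", "graph-database",
                "COLLECTION", "airport",
                "COMPONENT_TYPE", "vertex",
                "TRAVERSE", "g.V().hasLabel('airport').valueMap()");

        Map<String, String> sparqlArgs = Map.of(
                "DATABASE", "graph-database",
                "COLLECTION", "airport_rdf",
                "QUERY", "SELECT ?s ?p ?o WHERE { ?s ?p ?o }");

        System.out.println(route(gremlinArgs)); // prints: gremlin
        System.out.println(route(sparqlArgs));  // prints: sparql
    }

    private static String route(Map<String, String> qptArguments)
    {
        // Mirrors the if/else in the handler.
        return qptArguments.containsKey("TRAVERSE") ? "gremlin" : "sparql";
    }
}
```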
@@ -319,16 +330,17 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
                 }
 
             case RDF:
-                String sparqlQuery = qptArguments.get(NeptuneQueryPassthrough.QUERY);
-                sparqlQuery = sparqlQuery.contains("limit") ? sparqlQuery : sparqlQuery.concat("\n limit 1");
+                schemaName = qptArguments.get(NeptuneSparqlQueryPassthrough.DATABASE);
+                tableName = qptArguments.get(NeptuneSparqlQueryPassthrough.COLLECTION);
+                tableNameObj = new TableName(schemaName, tableName);
+                String sparqlQuery = qptArguments.get(NeptuneSparqlQueryPassthrough.QUERY);
+                sparqlQuery = sparqlQuery.contains("limit") ? sparqlQuery : sparqlQuery.concat("\n " + SPARQL_QUERY_LIMIT);
                 logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema sparql query with limit: " + sparqlQuery);
                 NeptuneSparqlConnection neptuneSparqlConnection = (NeptuneSparqlConnection) neptuneConnection;
                 neptuneSparqlConnection.runQuery(sparqlQuery);
-                String strim = getTableResponse.getSchema().getCustomMetadata().get(Constants.SCHEMA_STRIP_URI);
-                boolean trimURI = strim == null ? false : Boolean.parseBoolean(strim);
                 if (neptuneSparqlConnection.hasNext()) {
-                    Map<String, Object> resultsMap = neptuneSparqlConnection.next(trimURI);
-                    schema = getSchemaFromResults(getTableResponse, resultsMap, fields);
+                    Map<String, Object> resultsMap = neptuneSparqlConnection.next();
+                    schema = NeptuneSchemaUtils.getSchemaFromResults(resultsMap, RDF_COMPONENT_TYPE, tableName);
                     return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
                 }
                 else {
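
Note: both branches infer the schema from a single sample row. The Gremlin path appends "." + GREMLIN_QUERY_LIMIT to the traversal; the SPARQL path appends SPARQL_QUERY_LIMIT unless the query already contains "limit". A sketch of the SPARQL rule, assuming the constant resolves to "limit 1" (its real value lives in Constants, outside this diff):

```java
// Sketch of the SPARQL limit-appending rule for schema probing;
// SPARQL_QUERY_LIMIT's value is assumed.
public final class SparqlProbeLimit
{
    private static final String SPARQL_QUERY_LIMIT = "limit 1"; // assumed value

    static String withProbeLimit(String sparqlQuery)
    {
        // Case-sensitive check, as in the handler: a query that already
        // carries a lowercase "limit" is left untouched.
        return sparqlQuery.contains("limit") ? sparqlQuery : sparqlQuery.concat("\n " + SPARQL_QUERY_LIMIT);
    }

    public static void main(String[] args)
    {
        System.out.println(withProbeLimit("SELECT ?s ?p ?o WHERE { ?s ?p ?o }"));
        // -> SELECT ?s ?p ?o WHERE { ?s ?p ?o }
        //     limit 1
    }
}
```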
@@ -339,36 +351,4 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
                 throw new IllegalArgumentException("Unsupported graphType: " + graphType);
         }
     }
-
-    private Schema getSchemaFromResults(GetTableResponse getTableResponse, Map resultsMap, List<Field> fields)
-    {
-        Schema schema;
-        //In case of 'gremlin/sparql query' is fetching all columns then we can use same schema from glue
-        //otherwise we will build schema from gremlin/sparql query result column names.
-        if (resultsMap != null && resultsMap.size() == fields.size()) {
-            schema = getTableResponse.getSchema();
-        }
-        else {
-            SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
-            //Building schema from gremlin/sparql query results and list of fields from glue response.
-            //It's require only when we are selecting limited columns.
-            resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), fields, schemaBuilder));
-            Map<String, String> metaData = getTableResponse.getSchema().getCustomMetadata();
-            for (Map.Entry<String, String> map : metaData.entrySet()) {
-                schemaBuilder.addMetadata(map.getKey(), map.getValue());
-            }
-            schema = schemaBuilder.build();
-        }
-        return schema;
-    }
-
-    private void buildSchema(String columnName, List<Field> fields, SchemaBuilder schemaBuilder)
-    {
-        for (Field field : fields) {
-            if (field.getName().equals(columnName)) {
-                schemaBuilder.addField(field);
-                break;
-            }
-        }
-    }
 }
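
Note: the removed getSchemaFromResults/buildSchema helpers matched result columns against the Glue table's field list, so passthrough results were limited to columns Glue already knew about. Their replacement, NeptuneSchemaUtils.getSchemaFromResults(resultsMap, componentType, tableName), is not shown in this diff; below is a hedged sketch of what such a utility might do, deriving fields from the sample row itself. The field types and metadata keys are assumptions.

```java
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import org.apache.arrow.vector.types.pojo.Schema;

import java.util.Map;

// Sketch only: builds a schema from one sample result row instead of the Glue
// table definition. The real NeptuneSchemaUtils presumably maps Java value
// types to Arrow types; everything defaults to VARCHAR here.
public final class NeptuneSchemaUtilsSketch
{
    public static Schema getSchemaFromResults(Map<String, Object> resultsMap, String componentType, String tableName)
    {
        SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
        // One VARCHAR field per result column.
        resultsMap.keySet().forEach(schemaBuilder::addStringField);
        schemaBuilder.addMetadata("componenttype", componentType); // metadata keys are assumptions
        schemaBuilder.addMetadata("glabel", tableName);
        return schemaBuilder.build();
    }
}
```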