2424import com .amazonaws .athena .connector .lambda .data .BlockAllocatorImpl ;
2525import com .amazonaws .athena .connector .lambda .domain .TableName ;
2626import com .amazonaws .athena .connector .lambda .domain .predicate .Constraints ;
27+ import com .amazonaws .athena .connector .lambda .exceptions .AthenaConnectorException ;
28+ import com .amazonaws .athena .connector .lambda .metadata .GetDataSourceCapabilitiesRequest ;
29+ import com .amazonaws .athena .connector .lambda .metadata .GetDataSourceCapabilitiesResponse ;
2730import com .amazonaws .athena .connector .lambda .metadata .GetSplitsRequest ;
2831import com .amazonaws .athena .connector .lambda .metadata .GetSplitsResponse ;
2932import com .amazonaws .athena .connector .lambda .metadata .GetTableRequest ;
3235import com .amazonaws .athena .connector .lambda .metadata .ListSchemasResponse ;
3336import com .amazonaws .athena .connector .lambda .metadata .ListTablesRequest ;
3437import com .amazonaws .athena .connector .lambda .metadata .ListTablesResponse ;
38+ import com .amazonaws .athena .connector .lambda .metadata .optimizations .OptimizationSubType ;
3539import com .amazonaws .athena .connector .lambda .security .FederatedIdentity ;
3640import com .google .cloud .bigquery .BigQuery ;
41+ import com .google .cloud .bigquery .BigQueryOptions ;
3742import com .google .cloud .bigquery .Dataset ;
3843import com .google .cloud .bigquery .DatasetId ;
3944import com .google .cloud .bigquery .Field ;
4045import com .google .cloud .bigquery .FieldList ;
4146import com .google .cloud .bigquery .FieldValue ;
4247import com .google .cloud .bigquery .FieldValueList ;
4348import com .google .cloud .bigquery .Job ;
49+ import com .google .cloud .bigquery .JobInfo ;
50+ import com .google .cloud .bigquery .JobStatistics ;
4451import com .google .cloud .bigquery .JobStatus ;
4552import com .google .cloud .bigquery .LegacySQLTypeName ;
4653import com .google .cloud .bigquery .Schema ;
6269import com .amazonaws .athena .connector .lambda .security .LocalKeyFactory ;
6370
6471import java .io .IOException ;
65- import java .util .Arrays ;
6672import java .util .Collections ;
6773import java .util .HashMap ;
6874import java .util .List ;
6975import java .util .Map ;
7076
7177import static com .amazonaws .athena .connector .lambda .domain .predicate .Constraints .DEFAULT_NO_LIMIT ;
7278import static com .amazonaws .athena .connector .lambda .metadata .ListTablesRequest .UNLIMITED_PAGE_SIZE_VALUE ;
79+ import static com .amazonaws .athena .connector .lambda .metadata .optimizations .querypassthrough .QueryPassthroughSignature .ENABLE_QUERY_PASSTHROUGH ;
80+ import static com .amazonaws .athena .connector .lambda .metadata .optimizations .querypassthrough .QueryPassthroughSignature .SCHEMA_FUNCTION_NAME ;
81+ import static com .amazonaws .athena .connectors .google .bigquery .qpt .BigQueryQueryPassthrough .NAME ;
82+ import static com .amazonaws .athena .connectors .google .bigquery .qpt .BigQueryQueryPassthrough .QUERY ;
83+ import static com .amazonaws .athena .connectors .google .bigquery .qpt .BigQueryQueryPassthrough .SCHEMA_NAME ;
7384import static org .junit .Assert .assertEquals ;
85+ import static org .junit .Assert .assertFalse ;
7486import static org .junit .Assert .assertNotNull ;
87+ import static org .junit .Assert .assertThrows ;
88+ import static org .junit .Assert .assertTrue ;
7589import static org .mockito .ArgumentMatchers .any ;
7690import static org .mockito .ArgumentMatchers .nullable ;
7791import static org .mockito .Mockito .mock ;
92+ import static org .mockito .Mockito .mockStatic ;
93+ import static org .mockito .Mockito .verify ;
7894import static org .mockito .Mockito .when ;
7995
8096@ RunWith (MockitoJUnitRunner .class )
8197public class BigQueryMetadataHandlerTest
8298{
8399 private static final String QUERY_ID = "queryId" ;
84100 private static final String CATALOG = "catalog" ;
101+ private static final String SCHEMA = "testSchema" ;
102+ private static final String TABLE = "testTable" ;
85103 private static final TableName TABLE_NAME = new TableName ("dataset1" , "table1" );
86104
87105 @ Mock
@@ -104,10 +122,9 @@ public class BigQueryMetadataHandlerTest
104122 private MockedStatic <BigQueryUtils > mockedStatic ;
105123
106124 @ Before
107- public void setUp () throws InterruptedException , IOException
108- {
125+ public void setUp () {
109126 System .setProperty ("aws.region" , "us-east-1" );
110- MockitoAnnotations .initMocks (this );
127+ MockitoAnnotations .openMocks (this );
111128
112129 // Mock the SecretsManager response
113130 GetSecretValueResponse secretResponse = GetSecretValueResponse .builder ()
@@ -245,24 +262,23 @@ public void testDoGetTable() throws IOException
245262 }
246263
247264 @ Test
248- public void testDoGetSplits () throws Exception
249- {
265+ public void testDoGetSplits () {
250266
251267// mockedStatic.when(() -> BigQueryUtils.fixCaseForDatasetName(any(String.class), any(String.class), any(BigQuery.class))).thenReturn("testDataset");
252268// mockedStatic.when(() -> BigQueryUtils.fixCaseForTableName(any(String.class), any(String.class), any(String.class), any(BigQuery.class))).thenReturn("testTable");
253269 BlockAllocator blockAllocator = new BlockAllocatorImpl ();
254270 GetSplitsRequest request = new GetSplitsRequest (federatedIdentity ,
255271 QUERY_ID , CATALOG , TABLE_NAME ,
256- mock (Block .class ), Collections .< String > emptyList (), new Constraints (new HashMap <>(), Collections .emptyList (), Collections .emptyList (), DEFAULT_NO_LIMIT , Collections .emptyMap (), null ), null );
272+ mock (Block .class ), Collections .emptyList (), new Constraints (new HashMap <>(), Collections .emptyList (), Collections .emptyList (), DEFAULT_NO_LIMIT , Collections .emptyMap (), null ), null );
257273 // added schema with integer column countCol
258- List <Field > testSchemaFields = Arrays . asList (Field .of ("countCol" , LegacySQLTypeName .INTEGER ));
274+ List <Field > testSchemaFields = List . of (Field .of ("countCol" , LegacySQLTypeName .INTEGER ));
259275 Schema tableSchema = Schema .of (testSchemaFields );
260276
261277 // mocked table row count as 15
262- List <FieldValue > bigQueryRowValue = Arrays . asList (FieldValue .of (FieldValue .Attribute .PRIMITIVE , "15" ));
278+ List <FieldValue > bigQueryRowValue = List . of (FieldValue .of (FieldValue .Attribute .PRIMITIVE , "15" ));
263279 FieldValueList fieldValueList = FieldValueList .of (bigQueryRowValue ,
264280 FieldList .of (testSchemaFields ));
265- List <FieldValueList > tableRows = Arrays . asList (fieldValueList );
281+ List <FieldValueList > tableRows = List . of (fieldValueList );
266282
267283 GetSplitsResponse response = bigQueryMetadataHandler .doGetSplits (blockAllocator , request );
268284
@@ -288,4 +304,131 @@ public void testDoGetDataSourceCapabilities()
288304 assertNotNull (response );
289305 assertNotNull (response .getCapabilities ());
290306 }
307+
308+ @ Test
309+ public void testDoGetDataSourceCapabilities_VerifyOptimizations ()
310+ {
311+ GetDataSourceCapabilitiesRequest request =
312+ new GetDataSourceCapabilitiesRequest (federatedIdentity , QUERY_ID , CATALOG );
313+
314+ GetDataSourceCapabilitiesResponse response =
315+ bigQueryMetadataHandler .doGetDataSourceCapabilities (blockAllocator , request );
316+
317+ Map <String , List <OptimizationSubType >> capabilities = response .getCapabilities ();
318+
319+ assertEquals (CATALOG , response .getCatalogName ());
320+
321+ // Filter pushdown
322+ List <OptimizationSubType > filterPushdown = capabilities .get ("supports_filter_pushdown" );
323+ assertNotNull ("Expected supports_filter_pushdown capability to be present" , filterPushdown );
324+ assertEquals (2 , filterPushdown .size ());
325+ assertTrue (filterPushdown .stream ().anyMatch (subType -> subType .getSubType ().equals ("sorted_range_set" )));
326+ assertTrue (filterPushdown .stream ().anyMatch (subType -> subType .getSubType ().equals ("nullable_comparison" )));
327+
328+ // Complex expression pushdown
329+ List <OptimizationSubType > complexPushdown = capabilities .get ("supports_complex_expression_pushdown" );
330+ assertNotNull ("Expected supports_complex_expression_pushdown capability to be present" , complexPushdown );
331+ assertEquals (1 , complexPushdown .size ());
332+ OptimizationSubType complexSubType = complexPushdown .get (0 );
333+ assertEquals ("supported_function_expression_types" , complexSubType .getSubType ());
334+ assertNotNull ("Expected function expression types to be present" , complexSubType .getProperties ());
335+ assertFalse ("Expected function expression types to be non-empty" , complexSubType .getProperties ().isEmpty ());
336+
337+ // Top-N pushdown
338+ List <OptimizationSubType > topNPushdown = capabilities .get ("supports_top_n_pushdown" );
339+ assertNotNull ("Expected supports_top_n_pushdown capability to be present" , topNPushdown );
340+ assertEquals (1 , topNPushdown .size ());
341+ assertEquals ("SUPPORTS_ORDER_BY" , topNPushdown .get (0 ).getSubType ());
342+ }
343+
344+ @ Test
345+ public void testDoGetQueryPassthroughSchema_WhenEnabled_ShouldGetSucceeded () throws Exception {
346+ Map <String , String > queryPassthroughParameters = Map .of (
347+ SCHEMA_FUNCTION_NAME , "system.query" ,
348+ ENABLE_QUERY_PASSTHROUGH , "true" ,
349+ NAME , "query" ,
350+ SCHEMA_NAME , "system" ,
351+ QUERY , "select col1 from testTable" );
352+
353+ GetTableRequest getTableRequest = getTableRequest (queryPassthroughParameters );
354+
355+ try (MockedStatic <BigQueryOptions > mockBigQueryOptions = mockStatic (BigQueryOptions .class )) {
356+ BigQueryOptions options = mock (BigQueryOptions .class );
357+ mockBigQueryOptions .when (BigQueryOptions ::getDefaultInstance ).thenReturn (options );
358+ when (options .getService ()).thenReturn (bigQuery );
359+
360+ Field field = Field .of ("column1" , LegacySQLTypeName .STRING );
361+ com .google .cloud .bigquery .Schema schema = com .google .cloud .bigquery .Schema .of (field );
362+
363+ JobStatistics .QueryStatistics stats = mock (JobStatistics .QueryStatistics .class );
364+ when (stats .getStatementType ()).thenReturn (JobStatistics .QueryStatistics .StatementType .SELECT );
365+ when (stats .getSchema ()).thenReturn (schema );
366+ when (bigQuery .create (any (JobInfo .class ))).thenReturn (job );
367+ when (job .getStatistics ()).thenReturn (stats );
368+
369+ GetTableResponse response = bigQueryMetadataHandler .doGetQueryPassthroughSchema (blockAllocator , getTableRequest );
370+ assertNotNull (response );
371+ assertEquals (CATALOG , response .getCatalogName ());
372+ assertEquals (1 , response .getSchema ().getFields ().size ());
373+ verify (bigQuery ).create (any (JobInfo .class ));
374+ }
375+ }
376+
377+ @ Test
378+ public void testDoGetQueryPassthroughSchema_WithMissingQueryArg () {
379+ // Required QUERY parameter is missing
380+ Map <String , String > queryPassthroughParameters = Map .of (
381+ SCHEMA_FUNCTION_NAME , "system.query" ,
382+ ENABLE_QUERY_PASSTHROUGH , "true" ,
383+ NAME , "query" ,
384+ SCHEMA_NAME , "system" );
385+
386+ executeAndAssertTest (queryPassthroughParameters , "Missing Query Passthrough Argument: QUERY" );
387+ }
388+
389+ @ Test
390+ public void testDoGetQueryPassthroughSchema_WithMissingQueryValue () {
391+ // Required QUERY parameter value is missing
392+ Map <String , String > queryPassthroughParameters = Map .of (
393+ SCHEMA_FUNCTION_NAME , "system.query" ,
394+ ENABLE_QUERY_PASSTHROUGH , "true" ,
395+ NAME , "query" ,
396+ SCHEMA_NAME , "system" ,
397+ QUERY , "" );
398+ executeAndAssertTest (queryPassthroughParameters , "Missing Query Passthrough Value for Argument: QUERY" );
399+ }
400+
401+ @ Test
402+ public void testDoGetQueryPassthroughSchema_WithWrongSchemaNameArg () {
403+ // Schema function name is incorrect
404+ Map <String , String > queryPassthroughParameters = Map .of (
405+ SCHEMA_FUNCTION_NAME , "wrong.query" ,
406+ ENABLE_QUERY_PASSTHROUGH , "true" ,
407+ NAME , "query" ,
408+ SCHEMA_NAME , "system" ,
409+ QUERY , "select col1 from testTable" );
410+ executeAndAssertTest (queryPassthroughParameters , "Function Signature doesn't match implementation's" );
411+ }
412+
413+ private void executeAndAssertTest (Map <String , String > queryPassthroughParameters , String errorMessage ) {
414+ GetTableRequest getTableRequest = getTableRequest (queryPassthroughParameters );
415+
416+ Exception e = assertThrows (AthenaConnectorException .class , () ->
417+ bigQueryMetadataHandler .doGetQueryPassthroughSchema (blockAllocator , getTableRequest ));
418+ assertTrue (e .getMessage ().contains (errorMessage ));
419+ }
420+
421+ private GetTableRequest getTableRequest (Map <String , String > queryPassthroughParameters ) {
422+ return new GetTableRequest (federatedIdentity ,
423+ QUERY_ID , CATALOG ,
424+ new TableName (SCHEMA , TABLE ), queryPassthroughParameters );
425+ }
426+
427+ @ Test
428+ public void testDoGetQueryPassthroughSchema_WhenDisabled_ShouldThrowException () {
429+ GetTableRequest getTableRequest = getTableRequest (Collections .emptyMap ());
430+ Exception e = assertThrows (IllegalArgumentException .class , () ->
431+ bigQueryMetadataHandler .doGetQueryPassthroughSchema (blockAllocator , getTableRequest ));
432+ assertTrue (e .getMessage ().contains ("No Query passed through" ));
433+ }
291434}
0 commit comments