55import logging
66import sys
77from datetime import datetime , timedelta , timezone
8- from typing import Any , Dict , List , Optional , Tuple , Union , cast
8+ from typing import Any , Dict , List , Optional , Tuple , Union
99
10- import pyathena
1110from pyathena .aio .util import async_retry_api_call
1211from pyathena .common import BaseCursor
1312from pyathena .error import DatabaseError , OperationalError
14- from pyathena .model import AthenaQueryExecution
13+ from pyathena .model import AthenaDatabase , AthenaQueryExecution , AthenaTableMetadata
1514
1615_logger = logging .getLogger (__name__ ) # type: ignore
1716
@@ -36,13 +35,7 @@ async def _execute( # type: ignore[override]
3635 result_reuse_minutes : Optional [int ] = None ,
3736 paramstyle : Optional [str ] = None ,
3837 ) -> str :
39- if pyathena .paramstyle == "qmark" or paramstyle == "qmark" :
40- query = operation
41- execution_parameters = cast (Optional [List [str ]], parameters )
42- else :
43- query = self ._formatter .format (operation , cast (Optional [Dict [str , Any ]], parameters ))
44- execution_parameters = None
45- _logger .debug (query )
38+ query , execution_parameters = self ._prepare_query (operation , parameters , paramstyle )
4639
4740 request = self ._build_start_query_execution_request (
4841 query = query ,
@@ -216,3 +209,142 @@ async def _find_previous_query_id( # type: ignore[override]
216209 except Exception :
217210 _logger .warning ("Failed to check the cache. Moving on without cache." , exc_info = True )
218211 return query_id
212+
213+ async def _list_databases ( # type: ignore[override]
214+ self ,
215+ catalog_name : Optional [str ],
216+ next_token : Optional [str ] = None ,
217+ max_results : Optional [int ] = None ,
218+ ) -> Tuple [Optional [str ], List [AthenaDatabase ]]:
219+ request = self ._build_list_databases_request (
220+ catalog_name = catalog_name ,
221+ next_token = next_token ,
222+ max_results = max_results ,
223+ )
224+ try :
225+ response = await async_retry_api_call (
226+ self .connection ._client .list_databases ,
227+ config = self ._retry_config ,
228+ logger = _logger ,
229+ ** request ,
230+ )
231+ except Exception as e :
232+ _logger .exception ("Failed to list databases." )
233+ raise OperationalError (* e .args ) from e
234+ else :
235+ return response .get ("NextToken" ), [
236+ AthenaDatabase ({"Database" : r }) for r in response .get ("DatabaseList" , [])
237+ ]
238+
239+ async def list_databases ( # type: ignore[override]
240+ self ,
241+ catalog_name : Optional [str ],
242+ max_results : Optional [int ] = None ,
243+ ) -> List [AthenaDatabase ]:
244+ databases : List [AthenaDatabase ] = []
245+ next_token = None
246+ while True :
247+ next_token , response = await self ._list_databases (
248+ catalog_name = catalog_name ,
249+ next_token = next_token ,
250+ max_results = max_results ,
251+ )
252+ databases .extend (response )
253+ if not next_token :
254+ break
255+ return databases
256+
257+ async def _get_table_metadata ( # type: ignore[override]
258+ self ,
259+ table_name : str ,
260+ catalog_name : Optional [str ] = None ,
261+ schema_name : Optional [str ] = None ,
262+ logging_ : bool = True ,
263+ ) -> AthenaTableMetadata :
264+ request : Dict [str , Any ] = {
265+ "CatalogName" : catalog_name if catalog_name else self ._catalog_name ,
266+ "DatabaseName" : schema_name if schema_name else self ._schema_name ,
267+ "TableName" : table_name ,
268+ }
269+ if self ._work_group :
270+ request .update ({"WorkGroup" : self ._work_group })
271+ try :
272+ response = await async_retry_api_call (
273+ self ._connection .client .get_table_metadata ,
274+ config = self ._retry_config ,
275+ logger = _logger ,
276+ ** request ,
277+ )
278+ except Exception as e :
279+ if logging_ :
280+ _logger .exception ("Failed to get table metadata." )
281+ raise OperationalError (* e .args ) from e
282+ else :
283+ return AthenaTableMetadata (response )
284+
285+ async def get_table_metadata ( # type: ignore[override]
286+ self ,
287+ table_name : str ,
288+ catalog_name : Optional [str ] = None ,
289+ schema_name : Optional [str ] = None ,
290+ logging_ : bool = True ,
291+ ) -> AthenaTableMetadata :
292+ return await self ._get_table_metadata (
293+ table_name = table_name ,
294+ catalog_name = catalog_name ,
295+ schema_name = schema_name ,
296+ logging_ = logging_ ,
297+ )
298+
299+ async def _list_table_metadata ( # type: ignore[override]
300+ self ,
301+ catalog_name : Optional [str ] = None ,
302+ schema_name : Optional [str ] = None ,
303+ expression : Optional [str ] = None ,
304+ next_token : Optional [str ] = None ,
305+ max_results : Optional [int ] = None ,
306+ ) -> Tuple [Optional [str ], List [AthenaTableMetadata ]]:
307+ request = self ._build_list_table_metadata_request (
308+ catalog_name = catalog_name ,
309+ schema_name = schema_name ,
310+ expression = expression ,
311+ next_token = next_token ,
312+ max_results = max_results ,
313+ )
314+ try :
315+ response = await async_retry_api_call (
316+ self .connection ._client .list_table_metadata ,
317+ config = self ._retry_config ,
318+ logger = _logger ,
319+ ** request ,
320+ )
321+ except Exception as e :
322+ _logger .exception ("Failed to list table metadata." )
323+ raise OperationalError (* e .args ) from e
324+ else :
325+ return response .get ("NextToken" ), [
326+ AthenaTableMetadata ({"TableMetadata" : r })
327+ for r in response .get ("TableMetadataList" , [])
328+ ]
329+
330+ async def list_table_metadata ( # type: ignore[override]
331+ self ,
332+ catalog_name : Optional [str ] = None ,
333+ schema_name : Optional [str ] = None ,
334+ expression : Optional [str ] = None ,
335+ max_results : Optional [int ] = None ,
336+ ) -> List [AthenaTableMetadata ]:
337+ metadata : List [AthenaTableMetadata ] = []
338+ next_token = None
339+ while True :
340+ next_token , response = await self ._list_table_metadata (
341+ catalog_name = catalog_name ,
342+ schema_name = schema_name ,
343+ expression = expression ,
344+ next_token = next_token ,
345+ max_results = max_results ,
346+ )
347+ metadata .extend (response )
348+ if not next_token :
349+ break
350+ return metadata
0 commit comments