Skip to content

Commit 5627385

Browse files
Merge AioSparkBaseCursor into AioSparkCursor, remove aio/spark/common.py
Unlike sync Spark (SparkCursor + AsyncSparkCursor sharing SparkBaseCursor), the aio side has only AioSparkCursor, so a separate base class is unnecessary. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3601644 commit 5627385

File tree

2 files changed

+154
-166
lines changed

2 files changed

+154
-166
lines changed

pyathena/aio/spark/common.py

Lines changed: 0 additions & 161 deletions
This file was deleted.

pyathena/aio/spark/cursor.py

Lines changed: 154 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4+
import asyncio
45
import logging
56
from typing import Any, Dict, List, Optional, Union, cast
67

7-
from pyathena.aio.spark.common import AioSparkBaseCursor
8-
from pyathena.error import OperationalError, ProgrammingError
9-
from pyathena.model import AthenaCalculationExecution, AthenaCalculationExecutionStatus
10-
from pyathena.spark.common import WithCalculationExecution
8+
from pyathena.aio.util import async_retry_api_call
9+
from pyathena.error import NotSupportedError, OperationalError, ProgrammingError
10+
from pyathena.model import (
11+
AthenaCalculationExecution,
12+
AthenaCalculationExecutionStatus,
13+
AthenaQueryExecution,
14+
)
15+
from pyathena.spark.common import SparkBaseCursor, WithCalculationExecution
16+
from pyathena.util import parse_output_location
1117

1218
_logger = logging.getLogger(__name__) # type: ignore
1319

1420

15-
class AioSparkCursor(AioSparkBaseCursor, WithCalculationExecution):
21+
class AioSparkCursor(SparkBaseCursor, WithCalculationExecution):
1622
"""Native asyncio cursor for executing PySpark code on Athena.
1723
24+
Overrides post-init I/O methods of ``SparkBaseCursor`` with async
25+
equivalents. Session management (``_exists_session``,
26+
``_start_session``, etc.) stays synchronous because ``__init__``
27+
runs inside ``asyncio.to_thread``.
28+
1829
Since ``SparkBaseCursor.__init__`` performs I/O (session management),
1930
cursor creation must be wrapped in ``asyncio.to_thread``::
2031
@@ -53,6 +64,132 @@ def __init__(
5364
def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
5465
return self._calculation_execution
5566

67+
# --- async overrides of SparkBaseCursor I/O methods ---
68+
69+
async def _get_calculation_execution_status( # type: ignore[override]
70+
self, query_id: str
71+
) -> AthenaCalculationExecutionStatus:
72+
request: Dict[str, Any] = {"CalculationExecutionId": query_id}
73+
try:
74+
response = await async_retry_api_call(
75+
self._connection.client.get_calculation_execution_status,
76+
config=self._retry_config,
77+
logger=_logger,
78+
**request,
79+
)
80+
except Exception as e:
81+
_logger.exception("Failed to get calculation execution status.")
82+
raise OperationalError(*e.args) from e
83+
else:
84+
return AthenaCalculationExecutionStatus(response)
85+
86+
async def _get_calculation_execution( # type: ignore[override]
87+
self, query_id: str
88+
) -> AthenaCalculationExecution:
89+
request: Dict[str, Any] = {"CalculationExecutionId": query_id}
90+
try:
91+
response = await async_retry_api_call(
92+
self._connection.client.get_calculation_execution,
93+
config=self._retry_config,
94+
logger=_logger,
95+
**request,
96+
)
97+
except Exception as e:
98+
_logger.exception("Failed to get calculation execution.")
99+
raise OperationalError(*e.args) from e
100+
else:
101+
return AthenaCalculationExecution(response)
102+
103+
async def _calculate( # type: ignore[override]
104+
self,
105+
session_id: str,
106+
code_block: str,
107+
description: Optional[str] = None,
108+
client_request_token: Optional[str] = None,
109+
) -> str:
110+
request = self._build_start_calculation_execution_request(
111+
session_id=session_id,
112+
code_block=code_block,
113+
description=description,
114+
client_request_token=client_request_token,
115+
)
116+
try:
117+
response = await async_retry_api_call(
118+
self._connection.client.start_calculation_execution,
119+
config=self._retry_config,
120+
logger=_logger,
121+
**request,
122+
)
123+
calculation_id = response.get("CalculationExecutionId")
124+
except Exception as e:
125+
_logger.exception("Failed to execute calculation.")
126+
raise OperationalError(*e.args) from e
127+
return cast(str, calculation_id)
128+
129+
async def __poll( # type: ignore[override]
130+
self, query_id: str
131+
) -> Union[AthenaQueryExecution, AthenaCalculationExecution]:
132+
while True:
133+
calculation_status = await self._get_calculation_execution_status(query_id)
134+
if calculation_status.state in [
135+
AthenaCalculationExecutionStatus.STATE_COMPLETED,
136+
AthenaCalculationExecutionStatus.STATE_FAILED,
137+
AthenaCalculationExecutionStatus.STATE_CANCELED,
138+
]:
139+
return await self._get_calculation_execution(query_id)
140+
await asyncio.sleep(self._poll_interval)
141+
142+
async def _poll( # type: ignore[override]
143+
self, query_id: str
144+
) -> Union[AthenaQueryExecution, AthenaCalculationExecution]:
145+
try:
146+
query_execution = await self.__poll(query_id)
147+
except asyncio.CancelledError:
148+
if self._kill_on_interrupt:
149+
_logger.warning("Query canceled by user.")
150+
await self._cancel(query_id)
151+
query_execution = await self.__poll(query_id)
152+
else:
153+
raise
154+
return query_execution
155+
156+
async def _cancel(self, query_id: str) -> None: # type: ignore[override]
157+
request: Dict[str, Any] = {"CalculationExecutionId": query_id}
158+
try:
159+
await async_retry_api_call(
160+
self._connection.client.stop_calculation_execution,
161+
config=self._retry_config,
162+
logger=_logger,
163+
**request,
164+
)
165+
except Exception as e:
166+
_logger.exception("Failed to cancel calculation.")
167+
raise OperationalError(*e.args) from e
168+
169+
async def _terminate_session(self) -> None: # type: ignore[override]
170+
request: Dict[str, Any] = {"SessionId": self._session_id}
171+
try:
172+
await async_retry_api_call(
173+
self._connection.client.terminate_session,
174+
config=self._retry_config,
175+
logger=_logger,
176+
**request,
177+
)
178+
except Exception as e:
179+
_logger.exception("Failed to terminate session.")
180+
raise OperationalError(*e.args) from e
181+
182+
async def _read_s3_file_as_text(self, uri) -> str: # type: ignore[override]
183+
bucket, key = parse_output_location(uri)
184+
response = await asyncio.to_thread(
185+
self._client.get_object,
186+
Bucket=bucket,
187+
Key=key,
188+
)
189+
return cast(str, response["Body"].read().decode("utf-8").strip())
190+
191+
# --- public API ---
192+
56193
async def get_std_out(self) -> Optional[str]:
57194
"""Get the standard output from the Spark calculation execution.
58195
@@ -121,6 +258,18 @@ async def cancel(self) -> None:
121258
raise ProgrammingError("CalculationExecutionId is none or empty.")
122259
await self._cancel(self.calculation_id)
123260

261+
    async def close(self) -> None:  # type: ignore[override]
        """Close the cursor by terminating the Spark session.

        Delegates to the async ``_terminate_session`` override, which stops
        the Athena session held by this cursor.
        """
        await self._terminate_session()
264+
265+
    async def executemany(  # type: ignore[override]
        self,
        operation: str,
        seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
        **kwargs,
    ) -> None:
        """Batch execution is not supported for Spark calculations.

        Raises:
            NotSupportedError: Always; Spark cursors execute one code block
                at a time via ``execute``.
        """
        raise NotSupportedError
272+
124273
    def __aiter__(self):
        """Return self so the cursor can be used in ``async for`` loops."""
        return self
126275

0 commit comments

Comments
 (0)