Skip to content

Commit 6e273c0

Browse files
Extract _prepare_unload() helper into BaseCursor
The UNLOAD preparation block (validating s3_staging_dir and wrapping the query in an UNLOAD statement) was duplicated character-for-character across 9 cursor files. This commit consolidates that logic into a single _prepare_unload() method on BaseCursor; the method is pure computation, so it works for all cursor families (sync, async, aio). It also adds a return type annotation to Formatter.wrap_unload() to satisfy mypy.

Closes #670

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b753495 commit 6e273c0

File tree

11 files changed

+48
-119
lines changed

11 files changed

+48
-119
lines changed

pyathena/aio/arrow/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pyathena.arrow.result_set import AthenaArrowResultSet
1414
from pyathena.common import CursorIterator
1515
from pyathena.error import OperationalError, ProgrammingError
16-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
16+
from pyathena.model import AthenaQueryExecution
1717

1818
if TYPE_CHECKING:
1919
import polars as pl
@@ -110,18 +110,7 @@ async def execute( # type: ignore[override]
110110
Self reference for method chaining.
111111
"""
112112
self._reset_state()
113-
if self._unload:
114-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
115-
if not s3_staging_dir:
116-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
117-
operation, unload_location = self._formatter.wrap_unload(
118-
operation,
119-
s3_staging_dir=s3_staging_dir,
120-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
121-
compression=AthenaCompression.COMPRESSION_SNAPPY,
122-
)
123-
else:
124-
unload_location = None
113+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
125114
self.query_id = await self._execute(
126115
operation,
127116
parameters=parameters,

pyathena/aio/pandas/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pyathena.aio.common import WithAsyncFetch
1919
from pyathena.common import CursorIterator
2020
from pyathena.error import OperationalError, ProgrammingError
21-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
21+
from pyathena.model import AthenaQueryExecution
2222
from pyathena.pandas.converter import (
2323
DefaultPandasTypeConverter,
2424
DefaultPandasUnloadTypeConverter,
@@ -133,18 +133,7 @@ async def execute( # type: ignore[override]
133133
Self reference for method chaining.
134134
"""
135135
self._reset_state()
136-
if self._unload:
137-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
138-
if not s3_staging_dir:
139-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
140-
operation, unload_location = self._formatter.wrap_unload(
141-
operation,
142-
s3_staging_dir=s3_staging_dir,
143-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
144-
compression=AthenaCompression.COMPRESSION_SNAPPY,
145-
)
146-
else:
147-
unload_location = None
136+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
148137
self.query_id = await self._execute(
149138
operation,
150139
parameters=parameters,

pyathena/aio/polars/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pyathena.aio.common import WithAsyncFetch
1010
from pyathena.common import CursorIterator
1111
from pyathena.error import OperationalError, ProgrammingError
12-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
12+
from pyathena.model import AthenaQueryExecution
1313
from pyathena.polars.converter import (
1414
DefaultPolarsTypeConverter,
1515
DefaultPolarsUnloadTypeConverter,
@@ -115,18 +115,7 @@ async def execute( # type: ignore[override]
115115
Self reference for method chaining.
116116
"""
117117
self._reset_state()
118-
if self._unload:
119-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
120-
if not s3_staging_dir:
121-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
122-
operation, unload_location = self._formatter.wrap_unload(
123-
operation,
124-
s3_staging_dir=s3_staging_dir,
125-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
126-
compression=AthenaCompression.COMPRESSION_SNAPPY,
127-
)
128-
else:
129-
unload_location = None
118+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
130119
self.query_id = await self._execute(
131120
operation,
132121
parameters=parameters,

pyathena/arrow/async_cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pyathena.arrow.result_set import AthenaArrowResultSet
1515
from pyathena.async_cursor import AsyncCursor
1616
from pyathena.common import CursorIterator
17-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
17+
from pyathena.model import AthenaQueryExecution
1818

1919
_logger = logging.getLogger(__name__) # type: ignore
2020

@@ -182,18 +182,7 @@ def execute(
182182
paramstyle: Optional[str] = None,
183183
**kwargs,
184184
) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]:
185-
if self._unload:
186-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
187-
if not s3_staging_dir:
188-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
189-
operation, unload_location = self._formatter.wrap_unload(
190-
operation,
191-
s3_staging_dir=s3_staging_dir,
192-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
193-
compression=AthenaCompression.COMPRESSION_SNAPPY,
194-
)
195-
else:
196-
unload_location = None
185+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
197186
query_id = self._execute(
198187
operation,
199188
parameters=parameters,

pyathena/arrow/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pyathena.arrow.result_set import AthenaArrowResultSet
1212
from pyathena.common import CursorIterator
1313
from pyathena.error import OperationalError, ProgrammingError
14-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
14+
from pyathena.model import AthenaQueryExecution
1515
from pyathena.result_set import WithFetch
1616

1717
if TYPE_CHECKING:
@@ -166,18 +166,7 @@ def execute(
166166
>>> table = cursor.as_arrow() # Returns Apache Arrow Table
167167
"""
168168
self._reset_state()
169-
if self._unload:
170-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
171-
if not s3_staging_dir:
172-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
173-
operation, unload_location = self._formatter.wrap_unload(
174-
operation,
175-
s3_staging_dir=s3_staging_dir,
176-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
177-
compression=AthenaCompression.COMPRESSION_SNAPPY,
178-
)
179-
else:
180-
unload_location = None
169+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
181170
self.query_id = self._execute(
182171
operation,
183172
parameters=parameters,

pyathena/common.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from pyathena.model import (
1616
AthenaCalculationExecution,
1717
AthenaCalculationExecutionStatus,
18+
AthenaCompression,
1819
AthenaDatabase,
20+
AthenaFileFormat,
1921
AthenaQueryExecution,
2022
AthenaTableMetadata,
2123
)
@@ -652,6 +654,32 @@ def _prepare_query(
652654
_logger.debug(query)
653655
return query, execution_parameters
654656

657+
def _prepare_unload(
658+
self,
659+
operation: str,
660+
s3_staging_dir: Optional[str],
661+
) -> Tuple[str, Optional[str]]:
662+
"""Wrap operation with UNLOAD if enabled.
663+
664+
Args:
665+
operation: SQL query string.
666+
s3_staging_dir: S3 location for query results.
667+
668+
Returns:
669+
Tuple of (possibly-wrapped operation, unload_location or None).
670+
"""
671+
if not getattr(self, "_unload", False):
672+
return operation, None
673+
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
674+
if not s3_staging_dir:
675+
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
676+
return self._formatter.wrap_unload(
677+
operation,
678+
s3_staging_dir=s3_staging_dir,
679+
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
680+
compression=AthenaCompression.COMPRESSION_SNAPPY,
681+
)
682+
655683
def _execute(
656684
self,
657685
operation: str,

pyathena/formatter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import deepcopy
99
from datetime import date, datetime, timezone
1010
from decimal import Decimal
11-
from typing import Any, Callable, Dict, Optional, Type
11+
from typing import Any, Callable, Dict, Optional, Tuple, Type
1212

1313
from pyathena.error import ProgrammingError
1414
from pyathena.model import AthenaCompression, AthenaFileFormat
@@ -86,7 +86,7 @@ def wrap_unload(
8686
s3_staging_dir: str,
8787
format_: str = AthenaFileFormat.FILE_FORMAT_PARQUET,
8888
compression: str = AthenaCompression.COMPRESSION_SNAPPY,
89-
):
89+
) -> Tuple[str, Optional[str]]:
9090
"""Wrap a SELECT query with UNLOAD statement for high-performance result retrieval.
9191
9292
Transforms SELECT or WITH queries into UNLOAD statements that export results

pyathena/pandas/async_cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pyathena import ProgrammingError
1010
from pyathena.async_cursor import AsyncCursor
1111
from pyathena.common import CursorIterator
12-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
12+
from pyathena.model import AthenaQueryExecution
1313
from pyathena.pandas.converter import (
1414
DefaultPandasTypeConverter,
1515
DefaultPandasUnloadTypeConverter,
@@ -159,18 +159,7 @@ def execute(
159159
quoting: int = 1,
160160
**kwargs,
161161
) -> Tuple[str, "Future[Union[AthenaPandasResultSet, Any]]"]:
162-
if self._unload:
163-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
164-
if not s3_staging_dir:
165-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
166-
operation, unload_location = self._formatter.wrap_unload(
167-
operation,
168-
s3_staging_dir=s3_staging_dir,
169-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
170-
compression=AthenaCompression.COMPRESSION_SNAPPY,
171-
)
172-
else:
173-
unload_location = None
162+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
174163
query_id = self._execute(
175164
operation,
176165
parameters=parameters,

pyathena/pandas/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pyathena.common import CursorIterator
2020
from pyathena.error import OperationalError, ProgrammingError
21-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
21+
from pyathena.model import AthenaQueryExecution
2222
from pyathena.pandas.converter import (
2323
DefaultPandasTypeConverter,
2424
DefaultPandasUnloadTypeConverter,
@@ -193,18 +193,7 @@ def execute(
193193
>>> df = cursor.fetchall() # Returns pandas DataFrame
194194
"""
195195
self._reset_state()
196-
if self._unload:
197-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
198-
if not s3_staging_dir:
199-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
200-
operation, unload_location = self._formatter.wrap_unload(
201-
operation,
202-
s3_staging_dir=s3_staging_dir,
203-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
204-
compression=AthenaCompression.COMPRESSION_SNAPPY,
205-
)
206-
else:
207-
unload_location = None
196+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
208197
self.query_id = self._execute(
209198
operation,
210199
parameters=parameters,

pyathena/polars/async_cursor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pyathena import ProgrammingError
1010
from pyathena.async_cursor import AsyncCursor
1111
from pyathena.common import CursorIterator
12-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
12+
from pyathena.model import AthenaQueryExecution
1313
from pyathena.polars.converter import (
1414
DefaultPolarsTypeConverter,
1515
DefaultPolarsUnloadTypeConverter,
@@ -221,18 +221,7 @@ def execute(
221221
>>> result_set = future.result()
222222
>>> df = result_set.as_polars() # Returns Polars DataFrame
223223
"""
224-
if self._unload:
225-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
226-
if not s3_staging_dir:
227-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
228-
operation, unload_location = self._formatter.wrap_unload(
229-
operation,
230-
s3_staging_dir=s3_staging_dir,
231-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
232-
compression=AthenaCompression.COMPRESSION_SNAPPY,
233-
)
234-
else:
235-
unload_location = None
224+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
236225
query_id = self._execute(
237226
operation,
238227
parameters=parameters,

0 commit comments

Comments
 (0)