11from __future__ import annotations
22
3- from typing import Any
3+ from typing import Any , Literal , TYPE_CHECKING , overload
44
55from importlib .metadata import version
66
1111 get_meta as _get_meta ,
1212)
1313
14+ if TYPE_CHECKING :
15+ import pandas as pd
16+ import polars as pl
17+ import modin .pandas as mpd
18+ import dask .dataframe as dd
19+ import pyarrow as pa
20+
1421__version__ = version (__name__ )
1522
1623import os
2734 "CX_REWRITER_PATH" , os .path .join (dir_path , "dependencies/federated-rewriter.jar" )
2835)
2936
37+ Protocol = Literal ["csv" , "binary" , "cursor" , "simple" , "text" ]
3038
31- def rewrite_conn (conn : str , protocol : str | None = None ):
39+
40+ def rewrite_conn (conn : str , protocol : Protocol | None = None ) -> tuple [str , Protocol ]:
3241 if not protocol :
3342 # note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
3443 # drivers to connect. set a compatible protocol and masquerade as the appropriate backend.
@@ -47,8 +56,8 @@ def rewrite_conn(conn: str, protocol: str | None = None):
4756def get_meta (
4857 conn : str ,
4958 query : str ,
50- protocol : str | None = None ,
51- ):
59+ protocol : Protocol | None = None ,
60+ ) -> pd . DataFrame :
5261 """
5362 Get metadata (header) of the given query (only for pandas)
5463
@@ -75,7 +84,7 @@ def partition_sql(
7584 partition_on : str ,
7685 partition_num : int ,
7786 partition_range : tuple [int , int ] | None = None ,
78- ):
87+ ) -> list [ str ] :
7988 """
8089 Partition the sql query
8190
@@ -106,11 +115,11 @@ def read_sql_pandas(
106115 sql : list [str ] | str ,
107116 con : str | dict [str , str ],
108117 index_col : str | None = None ,
109- protocol : str | None = None ,
118+ protocol : Protocol | None = None ,
110119 partition_on : str | None = None ,
111120 partition_range : tuple [int , int ] | None = None ,
112121 partition_num : int | None = None ,
113- ):
122+ ) -> pd . DataFrame :
114123 """
115124 Run the SQL query, download the data from database into a dataframe.
116125 The first several parameters have the same names and order as in `pandas.read_sql`.
@@ -142,17 +151,103 @@ def read_sql_pandas(
142151 )
143152
144153
154+ # default overload: when return_type is omitted, the result is a pd.DataFrame
155+ @overload
145156def read_sql (
146157 conn : str | dict [str , str ],
147158 query : list [str ] | str ,
148159 * ,
149- return_type : str = "pandas" ,
150- protocol : str | None = None ,
160+ protocol : Protocol | None = None ,
151161 partition_on : str | None = None ,
152162 partition_range : tuple [int , int ] | None = None ,
153163 partition_num : int | None = None ,
154164 index_col : str | None = None ,
155- ):
165+ ) -> pd .DataFrame : ...
166+
167+
168+ @overload
169+ def read_sql (
170+ conn : str | dict [str , str ],
171+ query : list [str ] | str ,
172+ * ,
173+ return_type : Literal ["pandas" ],
174+ protocol : Protocol | None = None ,
175+ partition_on : str | None = None ,
176+ partition_range : tuple [int , int ] | None = None ,
177+ partition_num : int | None = None ,
178+ index_col : str | None = None ,
179+ ) -> pd .DataFrame : ...
180+
181+
182+ @overload
183+ def read_sql (
184+ conn : str | dict [str , str ],
185+ query : list [str ] | str ,
186+ * ,
187+ return_type : Literal ["arrow" , "arrow2" ],
188+ protocol : Protocol | None = None ,
189+ partition_on : str | None = None ,
190+ partition_range : tuple [int , int ] | None = None ,
191+ partition_num : int | None = None ,
192+ index_col : str | None = None ,
193+ ) -> pa .Table : ...
194+
195+
196+ @overload
197+ def read_sql (
198+ conn : str | dict [str , str ],
199+ query : list [str ] | str ,
200+ * ,
201+ return_type : Literal ["modin" ],
202+ protocol : Protocol | None = None ,
203+ partition_on : str | None = None ,
204+ partition_range : tuple [int , int ] | None = None ,
205+ partition_num : int | None = None ,
206+ index_col : str | None = None ,
207+ ) -> mpd .DataFrame : ...
208+
209+
210+ @overload
211+ def read_sql (
212+ conn : str | dict [str , str ],
213+ query : list [str ] | str ,
214+ * ,
215+ return_type : Literal ["dask" ],
216+ protocol : Protocol | None = None ,
217+ partition_on : str | None = None ,
218+ partition_range : tuple [int , int ] | None = None ,
219+ partition_num : int | None = None ,
220+ index_col : str | None = None ,
221+ ) -> dd .DataFrame : ...
222+
223+
224+ @overload
225+ def read_sql (
226+ conn : str | dict [str , str ],
227+ query : list [str ] | str ,
228+ * ,
229+ return_type : Literal ["polars" , "polars2" ],
230+ protocol : Protocol | None = None ,
231+ partition_on : str | None = None ,
232+ partition_range : tuple [int , int ] | None = None ,
233+ partition_num : int | None = None ,
234+ index_col : str | None = None ,
235+ ) -> pl .DataFrame : ...
236+
237+
238+ def read_sql (
239+ conn : str | dict [str , str ],
240+ query : list [str ] | str ,
241+ * ,
242+ return_type : Literal [
243+ "pandas" , "polars" , "polars2" , "arrow" , "arrow2" , "modin" , "dask"
244+ ] = "pandas" ,
245+ protocol : Protocol | None = None ,
246+ partition_on : str | None = None ,
247+ partition_range : tuple [int , int ] | None = None ,
248+ partition_num : int | None = None ,
249+ index_col : str | None = None ,
250+ ) -> pd .DataFrame | mpd .DataFrame | dd .DataFrame | pl .DataFrame | pa .Table :
156251 """
157252 Run the SQL query, download the data from database into a dataframe.
158253
@@ -318,7 +413,9 @@ def read_sql(
318413 return df
319414
320415
321- def reconstruct_arrow (result : tuple [list [str ], list [list [tuple [int , int ]]]]):
416+ def reconstruct_arrow (
417+ result : tuple [list [str ], list [list [tuple [int , int ]]]],
418+ ) -> pa .Table :
322419 import pyarrow as pa
323420
324421 names , ptrs = result
@@ -334,7 +431,7 @@ def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
334431 return pa .Table .from_batches (rbs )
335432
336433
337- def reconstruct_pandas (df_infos : dict [str , Any ]):
434+ def reconstruct_pandas (df_infos : dict [str , Any ]) -> pd . DataFrame :
338435 import pandas as pd
339436
340437 data = df_infos ["data" ]
@@ -388,6 +485,6 @@ def remove_ending_semicolon(query: str) -> str:
388485 SQL query
389486
390487 """
391- if query .endswith (';' ):
488+ if query .endswith (";" ):
392489 query = query [:- 1 ]
393490 return query
0 commit comments