Skip to content

Commit 9687211

Browse files
authored
Merge pull request #610 from zen-xu/detect-pkg
refactor: Optimize the detection of whether a package is installed.
2 parents 8a8a4c4 + 1cae814 commit 9687211

File tree

1 file changed

+16
-28
lines changed

1 file changed

+16
-28
lines changed

connectorx-python/connectorx/__init__.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3-
from typing import Any, Literal, TYPE_CHECKING, overload
43

4+
import importlib
55
from importlib.metadata import version
66

7+
from typing import Any, Literal, TYPE_CHECKING, overload
8+
79
from .connectorx import (
810
read_sql as _read_sql,
911
partition_sql as _partition_sql,
@@ -311,10 +313,7 @@ def read_sql(
311313
if return_type == "pandas":
312314
df = df.to_pandas(date_as_object=False, split_blocks=False)
313315
if return_type == "polars":
314-
try:
315-
import polars as pl
316-
except ModuleNotFoundError:
317-
raise ValueError("You need to install polars first")
316+
pl = try_import_module("polars")
318317

319318
try:
320319
# api change for polars >= 0.8.*
@@ -350,10 +349,7 @@ def read_sql(
350349
conn, protocol = rewrite_conn(conn, protocol)
351350

352351
if return_type in {"modin", "dask", "pandas"}:
353-
try:
354-
import pandas
355-
except ModuleNotFoundError:
356-
raise ValueError("You need to install pandas first")
352+
try_import_module("pandas")
357353

358354
result = _read_sql(
359355
conn,
@@ -368,25 +364,14 @@ def read_sql(
368364
df.set_index(index_col, inplace=True)
369365

370366
if return_type == "modin":
371-
try:
372-
import modin.pandas as mpd
373-
except ModuleNotFoundError:
374-
raise ValueError("You need to install modin first")
375-
367+
mpd = try_import_module("modin.pandas")
376368
df = mpd.DataFrame(df)
377369
elif return_type == "dask":
378-
try:
379-
import dask.dataframe as dd
380-
except ModuleNotFoundError:
381-
raise ValueError("You need to install dask first")
382-
370+
dd = try_import_module("dask.dataframe")
383371
df = dd.from_pandas(df, npartitions=1)
384372

385373
elif return_type in {"arrow", "arrow2", "polars", "polars2"}:
386-
try:
387-
import pyarrow
388-
except ModuleNotFoundError:
389-
raise ValueError("You need to install pyarrow first")
374+
try_import_module("pyarrow")
390375

391376
result = _read_sql(
392377
conn,
@@ -397,11 +382,7 @@ def read_sql(
397382
)
398383
df = reconstruct_arrow(result)
399384
if return_type in {"polars", "polars2"}:
400-
try:
401-
import polars as pl
402-
except ModuleNotFoundError:
403-
raise ValueError("You need to install polars first")
404-
385+
pl = try_import_module("polars")
405386
try:
406387
df = pl.DataFrame.from_arrow(df)
407388
except AttributeError:
@@ -488,3 +469,10 @@ def remove_ending_semicolon(query: str) -> str:
488469
if query.endswith(";"):
489470
query = query[:-1]
490471
return query
472+
473+
474+
def try_import_module(name: str):
475+
try:
476+
return importlib.import_module(name)
477+
except ModuleNotFoundError:
478+
raise ValueError(f"You need to install {name.split('.')[0]} first")

0 commit comments

Comments
 (0)