diff --git a/pandas/core/_numba/executor.py b/pandas/core/_numba/executor.py index c2b6191c05152..c9d3a9b1660ea 100644 --- a/pandas/core/_numba/executor.py +++ b/pandas/core/_numba/executor.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import Callable +from typing import ( + TYPE_CHECKING, + Callable, +) import numpy as np @@ -42,10 +45,12 @@ def generate_shared_aggregator( if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "column_looper" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def column_looper( values: np.ndarray, start: np.ndarray, diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index ea295af8d7910..24d66725caa70 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -3,6 +3,7 @@ import inspect from typing import ( + TYPE_CHECKING, Any, Callable, ) @@ -90,10 +91,12 @@ def generate_numba_agg_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "group_agg" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def group_agg( values: np.ndarray, index: np.ndarray, @@ -152,10 +155,12 @@ def generate_numba_transform_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "group_transform" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def group_transform( values: np.ndarray, index: np.ndarray, diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index c738213d4c487..06630989444bb 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -1,9 +1,11 @@ """Common utilities for Numba operations""" -# pyright: reportUntypedFunctionDecorator = false from __future__ import annotations import types -from typing import Callable +from typing import ( + TYPE_CHECKING, + Callable, +) import numpy as np @@ -84,7 +86,10 @@ def jit_user_function( function Numba JITed function """ - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") if numba.extending.is_jitted(func): # Don't jit a user passed jitted function diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index 74dc104b6db90..0e8eea3ec671e 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -1,8 +1,8 @@ -# pyright: reportUntypedFunctionDecorator = false from __future__ import annotations import functools from typing import ( + TYPE_CHECKING, Any, Callable, ) @@ -56,10 +56,12 @@ def generate_numba_apply_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "roll_apply" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_apply( values: np.ndarray, begin: np.ndarray, @@ -115,10 +117,12 @@ def generate_numba_ewm_func( if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "ewma" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def ewm( values: np.ndarray, begin: np.ndarray, @@ -217,10 +221,12 @@ def generate_numba_table_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "roll_table" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_table( values: np.ndarray, begin: np.ndarray, @@ -250,7 +256,10 @@ def roll_table( # https://github.com/numba/numba/issues/1269 @functools.lru_cache(maxsize=None) def generate_manual_numpy_nan_agg_with_axis(nan_func): - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") @numba.jit(nopython=True, nogil=True, parallel=True) def nan_agg_with_axis(table): @@ -296,10 +305,12 @@ def generate_numba_ewm_table_func( if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "ewm_table" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def ewm_table( values: np.ndarray, begin: np.ndarray, diff --git a/pandas/core/window/online.py b/pandas/core/window/online.py index e8804936da78f..8ef4aee154db4 100644 --- a/pandas/core/window/online.py +++ b/pandas/core/window/online.py @@ -1,4 +1,5 @@ from typing import ( + TYPE_CHECKING, Dict, Optional, ) @@ -31,10 +32,12 @@ def generate_online_numba_ewma_func(engine_kwargs: Optional[Dict[str, bool]]): if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "online_ewma" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def online_ewma( values: np.ndarray, deltas: np.ndarray, diff --git a/typings/numba.pyi b/typings/numba.pyi index 526119951a000..f877cbf339a8b 100644 --- a/typings/numba.pyi +++ b/typings/numba.pyi @@ -40,3 +40,4 @@ def jit( ) -> Callable[[F], F]: ... njit = jit +generated_jit = jit