From 8d2026ace8c848b9fe5c6d55cce3ba260f778174 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 1 Nov 2021 18:31:15 +0000 Subject: [PATCH 1/3] TYP: use TYPE_CHECKING for import_optional_dependency("numba") --- pandas/core/_numba/executor.py | 13 +++++++++---- pandas/core/groupby/numba_.py | 17 +++++++++++------ pandas/core/window/numba_.py | 33 +++++++++++++++++++++------------ pandas/core/window/online.py | 9 ++++++--- 4 files changed, 47 insertions(+), 25 deletions(-) diff --git a/pandas/core/_numba/executor.py b/pandas/core/_numba/executor.py index c2b6191c05152..c9d3a9b1660ea 100644 --- a/pandas/core/_numba/executor.py +++ b/pandas/core/_numba/executor.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import Callable +from typing import ( + TYPE_CHECKING, + Callable, +) import numpy as np @@ -42,10 +45,12 @@ def generate_shared_aggregator( if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "column_looper" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def column_looper( values: np.ndarray, start: np.ndarray, diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index ea295af8d7910..24d66725caa70 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -3,6 +3,7 @@ import inspect from typing import ( + TYPE_CHECKING, Any, Callable, ) @@ -90,10 +91,12 @@ def generate_numba_agg_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "group_agg" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def group_agg( values: np.ndarray, index: np.ndarray, @@ -152,10 +155,12 @@ def generate_numba_transform_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "group_transform" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def group_transform( values: np.ndarray, index: np.ndarray, diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index 74dc104b6db90..b753b95300586 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -3,6 +3,7 @@ import functools from typing import ( + TYPE_CHECKING, Any, Callable, ) @@ -56,10 +57,12 @@ def generate_numba_apply_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "roll_apply" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_apply( values: np.ndarray, begin: np.ndarray, @@ -115,10 +118,12 @@ def generate_numba_ewm_func( if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "ewma" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def ewm( values: np.ndarray, begin: np.ndarray, @@ -217,10 +222,12 @@ def generate_numba_table_func( return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "roll_table" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_table( values: np.ndarray, begin: np.ndarray, @@ -296,10 +303,12 @@ def generate_numba_ewm_table_func( if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "ewm_table" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def ewm_table( values: np.ndarray, begin: np.ndarray, diff --git a/pandas/core/window/online.py b/pandas/core/window/online.py index e8804936da78f..8ef4aee154db4 100644 --- a/pandas/core/window/online.py +++ b/pandas/core/window/online.py @@ -1,4 +1,5 @@ from typing import ( + TYPE_CHECKING, Dict, Optional, ) @@ -31,10 +32,12 @@ def generate_online_numba_ewma_func(engine_kwargs: Optional[Dict[str, bool]]): if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") - # error: Untyped decorator makes function "online_ewma" untyped - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc] + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def online_ewma( values: np.ndarray, deltas: np.ndarray, From 48d32695f25d469e5be1dd5937d6944b23682088 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Tue, 2 Nov 2021 10:17:48 +0000 Subject: [PATCH 2/3] remove reportUntypedFunctionDecorator = false --- pandas/core/window/numba_.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index b753b95300586..0e8eea3ec671e 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -1,4 +1,3 @@ -# pyright: reportUntypedFunctionDecorator = false from __future__ import annotations import functools @@ -257,7 +256,10 @@ def roll_table( # https://github.com/numba/numba/issues/1269 @functools.lru_cache(maxsize=None) def generate_manual_numpy_nan_agg_with_axis(nan_func): - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") @numba.jit(nopython=True, nogil=True, parallel=True) def nan_agg_with_axis(table): From 751453410aa038ba69d3ff25dc0902aba0f59a05 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Tue, 2 Nov 2021 10:50:44 +0000 Subject: [PATCH 3/3] generated_jit --- pandas/core/util/numba_.py | 11 ++++++++--- typings/numba.pyi | 1 + 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index c738213d4c487..06630989444bb 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -1,9 +1,11 @@ """Common utilities for Numba operations""" -# pyright: reportUntypedFunctionDecorator = false from __future__ import annotations import types -from typing import Callable +from typing import ( + TYPE_CHECKING, + Callable, +) import numpy as np @@ -84,7 +86,10 @@ def jit_user_function( function Numba JITed function """ - numba = import_optional_dependency("numba") + if TYPE_CHECKING: + import numba + else: + numba = import_optional_dependency("numba") if numba.extending.is_jitted(func): # Don't jit a user passed jitted function diff --git a/typings/numba.pyi b/typings/numba.pyi index 526119951a000..f877cbf339a8b 100644 --- a/typings/numba.pyi +++ b/typings/numba.pyi @@ -40,3 +40,4 @@ def jit( ) -> Callable[[F], F]: ... njit = jit +generated_jit = jit