Skip to content

Some numba backend fixes #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
import sys
import warnings
from contextlib import contextmanager
from functools import singledispatch
Expand All @@ -10,7 +11,7 @@
import numpy as np
import scipy
import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from llvmlite import ir
from numba import types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
Expand Down Expand Up @@ -48,10 +49,13 @@

def numba_njit(*args, **kwargs):

kwargs = kwargs.copy()
kwargs.setdefault("cache", config.numba__cache)

if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
return numba.njit(*args[1:], **kwargs)(args[0])

return numba.njit(*args, cache=config.numba__cache, **kwargs)
return numba.njit(*args, **kwargs)


def numba_vectorize(*args, **kwargs):
Expand Down Expand Up @@ -128,7 +132,7 @@ def create_numba_signature(


def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])

Expand All @@ -147,11 +151,33 @@ def box_slice(typ, val, c):
This makes it possible to return an Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.builder.extract_value(val, 0)
stop = c.builder.extract_value(val, 1)

none_val = ir.Constant(ir.IntType(64), sys.maxsize)

start_is_none = c.builder.icmp_signed("==", start, none_val)
start = c.builder.select(
start_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, start),
)

stop_is_none = c.builder.icmp_signed("==", stop, none_val)
stop = c.builder.select(
stop_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, stop),
)

start = c.box(types.int64, c.builder.extract_value(val, 0))
stop = c.box(types.int64, c.builder.extract_value(val, 1))
if typ.has_step:
step = c.box(types.int64, c.builder.extract_value(val, 2))
step = c.builder.extract_value(val, 2)
step_is_none = c.builder.icmp_signed("==", step, none_val)
step = c.builder.select(
step_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, step),
)
else:
step = c.pyapi.get_null_object()

Expand Down Expand Up @@ -319,9 +345,8 @@ def numba_typify(data, dtype=None, **kwargs):
return data


@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an PyTensor `Op`."""
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""

warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
Expand Down Expand Up @@ -375,6 +400,12 @@ def perform(*inputs):
return perform


@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Generate a numba function for a given op and apply node."""
return generate_fallback_impl(op, node, storage_map, **kwargs)


@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):

Expand Down Expand Up @@ -506,7 +537,6 @@ def {fn_name}({", ".join(input_names)}):


@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):

Expand All @@ -524,7 +554,6 @@ def numba_funcify_Subtensor(op, node, **kwargs):


@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):

incsubtensor_def_src = create_index_func(
Expand Down
211 changes: 211 additions & 0 deletions pytensor/link/numba/dispatch/cython_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import ctypes
import importlib
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast

import numba
import numpy as np
from numpy.typing import DTypeLike
from scipy import LowLevelCallable


_C_TO_NUMPY: Dict[str, DTypeLike] = {
"bool": np.bool_,
"signed char": np.byte,
"unsigned char": np.ubyte,
"short": np.short,
"unsigned short": np.ushort,
"int": np.intc,
"unsigned int": np.uintc,
"long": np.int_,
"unsigned long": np.uint,
"long long": np.longlong,
"float": np.single,
"double": np.double,
"long double": np.longdouble,
"float complex": np.csingle,
"double complex": np.cdouble,
}


@dataclass
class Signature:
res_dtype: DTypeLike
Copy link
Member

@ferrine ferrine Dec 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it a single output signature?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this currently only supports scalar return values. This is all we need for the scipy.special functions, but if we need more later (maybe for some linalg stuff?) we might have to expand it a bit.

res_c_type: str
arg_dtypes: List[DTypeLike]
arg_c_types: List[str]
arg_names: List[Optional[str]]

@property
def arg_numba_types(self) -> List[DTypeLike]:
return [numba.from_dtype(dtype) for dtype in self.arg_dtypes]

def can_cast_args(self, args: List[DTypeLike]) -> bool:
ok = True
count = 0
for name, dtype in zip(self.arg_names, self.arg_dtypes):
if name == "__pyx_skip_dispatch":
continue
if len(args) <= count:
raise ValueError("Incorrect number of arguments")
ok &= np.can_cast(args[count], dtype)
count += 1
if count != len(args):
return False
return ok

def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool:
args_ok = self.can_cast_args(arg_dtypes)
if np.issubdtype(restype, np.inexact):
result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
# We do not want to provide less accuracy than advertised
result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
else:
result_ok = np.can_cast(self.res_dtype, restype)
return args_ok and result_ok

@staticmethod
def from_c_types(signature: bytes) -> "Signature":
# Match strings like "double(int, double)"
# and extract the return type and the joined arguments
expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)")
re_match = re.fullmatch(expr, signature)

if re_match is None:
raise ValueError(f"Invalid signature: {signature.decode()}")

groups = re_match.groupdict()
res_c_type = groups["restype"].decode()
res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]

raw_args = groups["args"]

decl_expr = re.compile(
rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
rb"(\s(?P<name>[\w_]*))?\s*"
)

arg_dtypes = []
arg_names: List[Optional[str]] = []
arg_c_types = []
for raw_arg in raw_args.split(b","):
re_match = re.fullmatch(decl_expr, raw_arg)
if re_match is None:
raise ValueError(f"Invalid signature: {signature.decode()}")
groups = re_match.groupdict()
arg_c_type = groups["type"].decode()
try:
arg_dtype = _C_TO_NUMPY[arg_c_type]
except KeyError:
raise ValueError(f"Unknown C type: {arg_c_type}")

arg_c_types.append(arg_c_type)
arg_dtypes.append(arg_dtype)
name = groups["name"]
if not name:
arg_names.append(None)
else:
arg_names.append(name.decode())

return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)


def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
"""Find all available implementations for a fused cython function."""
impls = []
mod = importlib.import_module(func.__module__)

signatures = getattr(func, "__signatures__", None)
if signatures is not None:
# Cython function with __signatures__ should be fused and thus
# indexable
func_map = cast(Mapping, func)
candidates = [func_map[key] for key in signatures]
else:
candidates = [func]
for candidate in candidates:
name = candidate.__name__
capsule = mod.__pyx_capi__[name]
llc = LowLevelCallable(capsule)
try:
signature = Signature.from_c_types(llc.signature.encode())
except KeyError:
continue
impls.append((signature, capsule))
return impls


class _CythonWrapper(numba.types.WrapperAddressProtocol):
def __init__(self, pyfunc, signature, capsule):
self._keep_alive = capsule
get_name = ctypes.pythonapi.PyCapsule_GetName
get_name.restype = ctypes.c_char_p
get_name.argtypes = (ctypes.py_object,)

raw_signature = get_name(capsule)

get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
get_pointer.restype = ctypes.c_void_p
get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
self._func_ptr = get_pointer(capsule, raw_signature)

self._signature = signature
self._pyfunc = pyfunc

def signature(self):
return numba.from_dtype(self._signature.res_dtype)(
*self._signature.arg_numba_types
)

def __wrapper_address__(self):
return self._func_ptr

def __call__(self, *args, **kwargs):
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
if self.has_pyx_skip_dispatch():
output = self._pyfunc(*args[:-1], **kwargs)
else:
output = self._pyfunc(*args, **kwargs)
return self._signature.res_dtype(output)

def has_pyx_skip_dispatch(self):
if not self._signature.arg_names:
return False
if any(
name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1]
):
raise ValueError("skip_dispatch parameter must be last")
return self._signature.arg_names[-1] == "__pyx_skip_dispatch"

def numpy_arg_dtypes(self):
return self._signature.arg_dtypes

def numpy_output_dtype(self):
return self._signature.res_dtype


def wrap_cython_function(func, restype, arg_types):
impls = _available_impls(func)
compatible = []
for sig, capsule in impls:
if sig.provides(restype, arg_types):
compatible.append((sig, capsule))

def sort_key(args):
sig, _ = args

# Prefer functions with less inputs bytes
argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes)

# Prefer functions with more exact (integer) arguments
num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes)
return (num_inexact, argsize)

compatible.sort(key=sort_key)

if not compatible:
raise NotImplementedError(f"Could not find a compatible impl of {func}")
sig, capsule = compatible[0]
return _CythonWrapper(func, sig, capsule)
Loading