From 570f703ef919232db6bdaa84739621b67f57dba8 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Mon, 25 Sep 2023 16:53:02 +0200 Subject: [PATCH 1/4] Copy JAX backend to PyTorch and ask GPT to port it. --- pytensor/link/pytorch/dispatch/__init__.py | 17 + pytensor/link/pytorch/dispatch/basic.py | 109 +++++++ pytensor/link/pytorch/dispatch/elemwise.py | 113 +++++++ pytensor/link/pytorch/dispatch/extra_ops.py | 129 ++++++++ pytensor/link/pytorch/dispatch/nlinalg.py | 144 +++++++++ pytensor/link/pytorch/dispatch/random.py | 297 ++++++++++++++++++ pytensor/link/pytorch/dispatch/scalar.py | 282 +++++++++++++++++ pytensor/link/pytorch/dispatch/scan.py | 194 ++++++++++++ pytensor/link/pytorch/dispatch/shape.py | 114 +++++++ pytensor/link/pytorch/dispatch/slinalg.py | 43 +++ pytensor/link/pytorch/dispatch/sparse.py | 37 +++ pytensor/link/pytorch/dispatch/subtensor.py | 118 +++++++ .../link/pytorch/dispatch/tensor_basic.py | 207 ++++++++++++ 13 files changed, 1804 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/__init__.py create mode 100644 pytensor/link/pytorch/dispatch/basic.py create mode 100644 pytensor/link/pytorch/dispatch/elemwise.py create mode 100644 pytensor/link/pytorch/dispatch/extra_ops.py create mode 100644 pytensor/link/pytorch/dispatch/nlinalg.py create mode 100644 pytensor/link/pytorch/dispatch/random.py create mode 100644 pytensor/link/pytorch/dispatch/scalar.py create mode 100644 pytensor/link/pytorch/dispatch/scan.py create mode 100644 pytensor/link/pytorch/dispatch/shape.py create mode 100644 pytensor/link/pytorch/dispatch/slinalg.py create mode 100644 pytensor/link/pytorch/dispatch/sparse.py create mode 100644 pytensor/link/pytorch/dispatch/subtensor.py create mode 100644 pytensor/link/pytorch/dispatch/tensor_basic.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py new file mode 100644 index 0000000000..d091f8328a --- /dev/null +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -0,0 +1,17 @@ +# isort: off +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify + +# Load dispatch specializations +import pytensor.link.pytorch.dispatch.scalar +import pytensor.link.pytorch.dispatch.tensor_basic +import pytensor.link.pytorch.dispatch.subtensor +import pytensor.link.pytorch.dispatch.shape +import pytensor.link.pytorch.dispatch.extra_ops +import pytensor.link.pytorch.dispatch.nlinalg +import pytensor.link.pytorch.dispatch.slinalg +import pytensor.link.pytorch.dispatch.random +import pytensor.link.pytorch.dispatch.elemwise +import pytensor.link.pytorch.dispatch.scan +import pytensor.link.pytorch.dispatch.sparse + +# isort: on diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py new file mode 100644 index 0000000000..da382fddb4 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -0,0 +1,109 @@ +import warnings +from functools import singledispatch + +import torch +import numpy as np + +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.ifelse import IfElse +from pytensor.link.utils import fgraph_to_python +from pytensor.raise_op import Assert, CheckAndRaise + + +@singledispatch +def torch_typify(data, dtype=None, **kwargs): + r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" + if dtype is None: + return data + else: + return torch.tensor(data, dtype=dtype) + + +@torch_typify.register(np.ndarray) +def 
torch_typify_ndarray(data, dtype=None, **kwargs): + if len(data.shape) == 0: + return data.item() + return torch.tensor(data, dtype=dtype) + + +@singledispatch +def torch_funcify(op, node=None, storage_map=None, **kwargs): + """Create a PyTorch compatible function from an PyTensor `Op`.""" + raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}") + + +@torch_funcify.register(FunctionGraph) +def torch_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="torch_funcified_fgraph", + **kwargs, +): + return fgraph_to_python( + fgraph, + torch_funcify, + type_conversion_fn=torch_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + +@torch_funcify.register(IfElse) +def torch_funcify_IfElse(op, **kwargs): + n_outs = op.n_outs + + def ifelse(cond, *args, n_outs=n_outs): + res = torch.where( + cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None + ) + return res if n_outs > 1 else res[0] + + return ifelse + + +@torch_funcify.register(Assert) +@torch_funcify.register(CheckAndRaise) +def torch_funcify_CheckAndRaise(op, **kwargs): + warnings.warn( + f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as PyTorch tracing would remove it.""", + stacklevel=2, + ) + + def assert_fn(x, *inputs): + return x + + return assert_fn + + +def torch_safe_copy(x): + try: + res = torch.clone(x) + except NotImplementedError: + warnings.warn( + "`torch.clone` is not implemented yet. Using the object's `copy` method." + ) + if hasattr(x, "copy"): + res = torch.tensor(x.copy()) + else: + warnings.warn(f"Object has no `copy` method: {x}") + res = x + + return res + + +@torch_funcify.register(DeepCopyOp) +def torch_funcify_DeepCopyOp(op, **kwargs): + def deepcopyop(x): + return torch_safe_copy(x) + + return deepcopyop + + +@torch_funcify.register(ViewOp) +def torch_funcify_ViewOp(op, **kwargs): + def viewop(x): + return x + + return viewop \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py new file mode 100644 index 0000000000..ab3c001230 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -0,0 +1,113 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad + + +@pytorch_funcify.register(Elemwise) +def pytorch_funcify_Elemwise(op, node, **kwargs): + scalar_op = op.scalar_op + base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) + + def elemwise_fn(*inputs): + # ScalarVariables in PyTorch are passed as int/float. 
+ # We wrap them in tensors just for the broadcast check + Elemwise._check_runtime_broadcast(node, tuple(map(torch.tensor, inputs))) + return base_fn(*inputs) + + return elemwise_fn + + +@pytorch_funcify.register(CAReduce) +def pytorch_funcify_CAReduce(op, **kwargs): + axis = op.axis + op_nfunc_spec = getattr(op, "nfunc_spec", None) + scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) + scalar_op_name = getattr(op.scalar_op, "name", None) + scalar_op_identity = getattr(op.scalar_op, "identity", None) + acc_dtype = getattr(op, "acc_dtype", None) + + def careduce(x): + nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype + + if axis is None: + axis = list(range(x.ndim)) + + if acc_dtype is None: + acc_dtype = x.dtype.type + + if op_nfunc_spec: + torch_op = getattr(torch, op_nfunc_spec[0]) + return torch_op(x, axis=axis).type(acc_dtype) + + # The PyTensor `Op` didn't tell us which PyTorch equivalent to use (or + # there isn't one), so we use this fallback approach + if scalar_nfunc_spec: + scalar_fn_name = scalar_nfunc_spec[0] + elif scalar_op_name: + scalar_fn_name = scalar_op_name + + to_reduce = sorted(axis, reverse=True) + + if to_reduce: + # In this case, we need to use the `torch` function (if there + # is one), and not the `torch` version. + torch_op = getattr(torch, scalar_fn_name) + init_value = torch.tensor(scalar_op_identity, dtype=acc_dtype) + return torch.reduce(x, init_value, torch_op, to_reduce).type(acc_dtype) + else: + return x + + return careduce + + +@pytorch_funcify.register(DimShuffle) +def pytorch_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + res = torch.transpose(x, op.transposition) + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + res = torch.reshape(res, shape) + + if not op.inplace: + res = torch.clone(res) + + return res + + return dimshuffle + + +@pytorch_funcify.register(Softmax) +def pytorch_funcify_Softmax(op, **kwargs): + axis = op.axis + + def softmax(x): + return torch.nn.functional.softmax(x, dim=axis) + + return softmax + + +@pytorch_funcify.register(SoftmaxGrad) +def pytorch_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm + + return softmax_grad + + +@pytorch_funcify.register(LogSoftmax) +def pytorch_funcify_LogSoftmax(op, **kwargs): + axis = op.axis + + def log_softmax(x): + return torch.nn.functional.log_softmax(x, dim=axis) + + return log_softmax diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py new file mode 100644 index 0000000000..366b6fa9f5 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/extra_ops.py @@ -0,0 +1,129 @@ +import warnings +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.extra_ops import ( + Bartlett, + CumOp, + FillDiagonal, + FillDiagonalOffset, + RavelMultiIndex, + Repeat, + Unique, + UnravelIndex, +) + + +@pytorch_funcify.register(Bartlett) +def pytorch_funcify_Bartlett(op, **kwargs): + def bartlett(x): + return torch.bartlett_window(x) + + return bartlett + + +@pytorch_funcify.register(CumOp) +def pytorch_funcify_CumOp(op, **kwargs): + axis = op.axis + mode = op.mode + + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return torch.cumsum(x, axis=axis) + else: + return torch.cumprod(x, axis=axis) + + return cumop + + +@pytorch_funcify.register(Repeat) +def 
pytorch_funcify_Repeat(op, **kwargs): + axis = op.axis + + def repeatop(x, repeats, axis=axis): + return x.repeat(repeats, axis=axis) + + return repeatop + + +@pytorch_funcify.register(Unique) +def pytorch_funcify_Unique(op, **kwargs): + axis = op.axis + + if axis is not None: + raise NotImplementedError( + "torch.unique is not implemented for the axis argument" + ) + + return_index = op.return_index + return_inverse = op.return_inverse + return_counts = op.return_counts + + def unique( + x, + return_index=return_index, + return_inverse=return_inverse, + return_counts=return_counts, + axis=axis, + ): + ret = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts) + if len(ret) == 1: + return ret[0] + else: + return ret + + return unique + + +@pytorch_funcify.register(UnravelIndex) +def pytorch_funcify_UnravelIndex(op, **kwargs): + order = op.order + + warnings.warn("PyTorch ignores the `order` parameter in `unravel_index`.") + + def unravelindex(indices, dims, order=order): + return torch.unravel_index(indices, dims) + + return unravelindex + + +@pytorch_funcify.register(RavelMultiIndex) +def pytorch_funcify_RavelMultiIndex(op, **kwargs): + mode = op.mode + order = op.order + + def ravelmultiindex(*inp, mode=mode, order=order): + multi_index, dims = inp[:-1], inp[-1] + return torch.ravel_multi_index(multi_index, dims, mode=mode, order=order) + + return ravelmultiindex + + +@pytorch_funcify.register(FillDiagonal) +def pytorch_funcify_FillDiagonal(op, **kwargs): + def filldiagonal(value, diagonal): + value.fill_diagonal_(diagonal) + return value + + return filldiagonal + + +@pytorch_funcify.register(FillDiagonalOffset) +def pytorch_funcify_FillDiagonalOffset(op, **kwargs): + def filldiagonaloffset(a, val, offset): + height, width = a.shape + + if offset >= 0: + start = offset + num_of_step = min(min(width, height), width - offset) + else: + start = -offset * a.shape[1] + num_of_step = min(min(width, height), height + offset) + + step = a.shape[1] + 1 + end = start + step * num_of_step + a.view(-1)[start:end:step] = val + + return a + + return filldiagonaloffset diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py new file mode 100644 index 0000000000..a174964699 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -0,0 +1,144 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.blas import BatchedDot +from pytensor.tensor.math import Dot, MaxAndArgmax +from pytensor.tensor.nlinalg import ( + SVD, + Det, + Eig, + Eigh, + MatrixInverse, + MatrixPinv, + QRFull, + SLogDet, +) + + +@pytorch_funcify.register(SVD) +def pytorch_funcify_SVD(op, **kwargs): + full_matrices = op.full_matrices + compute_uv = op.compute_uv + + def svd(x, full_matrices=full_matrices, compute_uv=compute_uv): + return torch.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + + return svd + + +@pytorch_funcify.register(Det) +def pytorch_funcify_Det(op, **kwargs): + def det(x): + return torch.det(x) + + return det + + +@pytorch_funcify.register(SLogDet) +def pytorch_funcify_SLogDet(op, **kwargs): + def slogdet(x): + return torch.slogdet(x) + + return slogdet + + +@pytorch_funcify.register(Eig) +def pytorch_funcify_Eig(op, **kwargs): + def eig(x): + return torch.eig(x) + + return eig + + +@pytorch_funcify.register(Eigh) +def pytorch_funcify_Eigh(op, **kwargs): + uplo = op.UPLO + + def eigh(x, uplo=uplo): + return torch.linalg.eigh(x, UPLO=uplo) + + return eigh + + 
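+# NOTE: ``torch.eig`` (used for ``Eig`` above) has been removed from recent PyTorch
+# releases and ``torch.qr`` (used for ``QRFull`` below) is deprecated and does not
+# accept NumPy-style ``mode`` strings; the maintained equivalents are
+# ``torch.linalg.eig`` and ``torch.linalg.qr``, so those dispatches may need updating.
+
+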
+@pytorch_funcify.register(MatrixInverse) +def pytorch_funcify_MatrixInverse(op, **kwargs): + def matrix_inverse(x): + return torch.inverse(x) + + return matrix_inverse + + +@pytorch_funcify.register(QRFull) +def pytorch_funcify_QRFull(op, **kwargs): + mode = op.mode + + def qr_full(x, mode=mode): + return torch.qr(x, mode=mode) + + return qr_full + + +@pytorch_funcify.register(Dot) +def pytorch_funcify_Dot(op, **kwargs): + def dot(x, y): + return torch.dot(x, y) + + return dot + + +@pytorch_funcify.register(MatrixPinv) +def pytorch_funcify_Pinv(op, **kwargs): + def pinv(x): + return torch.pinverse(x) + + return pinv + + +@pytorch_funcify.register(BatchedDot) +def pytorch_funcify_BatchedDot(op, **kwargs): + def batched_dot(a, b): + if a.shape[0] != b.shape[0]: + raise TypeError("Shapes must match in the 0-th dimension") + if a.ndim == 2 or b.ndim == 2: + return torch.einsum("n...j,nj...->n...", a, b) + return torch.einsum("nij,njk->nik", a, b) + + return batched_dot + + +@pytorch_funcify.register(MaxAndArgmax) +def pytorch_funcify_MaxAndArgmax(op, **kwargs): + axis = op.axis + + def maxandargmax(x, axis=axis): + if axis is None: + axes = tuple(range(x.ndim)) + else: + axes = tuple(int(ax) for ax in axis) + + max_res = torch.max(x, axis) + + # NumPy does not support multiple axes for argmax; this is a + # work-around + keep_axes = torch.tensor( + [i for i in range(x.ndim) if i not in axes], dtype=torch.int64 + ) + # Not-reduced axes in front + transposed_x = torch.transpose( + x, torch.cat((keep_axes, torch.tensor(axes, dtype=torch.int64))) + ) + kept_shape = transposed_x.shape[: len(keep_axes)] + reduced_shape = transposed_x.shape[len(keep_axes) :] + + # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 + # Otherwise reshape would complain citing float arg + new_shape = kept_shape + ( + torch.prod(torch.tensor(reduced_shape, dtype=torch.int64), dtype=torch.int64), + ) + reshaped_x = transposed_x.reshape(new_shape) + + max_idx_res = torch.argmax(reshaped_x, axis=-1).type(torch.int64) + + return max_res, max_idx_res + + return maxandargmax \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/random.py b/pytensor/link/pytorch/dispatch/random.py new file mode 100644 index 0000000000..e5d881b35c --- /dev/null +++ b/pytensor/link/pytorch/dispatch/random.py @@ -0,0 +1,297 @@ +import torch +from functools import singledispatch + +import numpy as np +from numpy.random import Generator, RandomState +from numpy.random.bit_generator import ( # type: ignore[attr-defined] + _coerce_to_uint32_array, +) + +import pytensor.tensor.random.basic as aer +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify +from pytensor.link.pytorch.dispatch.shape import PyTorchShapeTuple +from pytensor.tensor.shape import Shape, Shape_i + +numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3} + +SIZE_NOT_COMPATIBLE = """PyTorch random variables require concrete values for the `size` parameter of the distributions. +Concrete values are either constants: + +>>> import pytensor.tensor as at +>>> x_rv = at.random.normal(0, 1, size=(3, 2)) + +or the shape of an array: + +>>> m = at.matrix() +>>> x_rv = at.random.normal(0, 1, size=m.shape) +""" + +def assert_size_argument_pytorch_compatible(node): + """Assert whether the current node can be JIT-compiled by PyTorch. + + PyTorch can JIT-compile `torch.random` functions when the `size` argument + is a concrete value, i.e. either a constant or the shape of any + traced value. 
+ + """ + size = node.inputs[1] + size_node = size.owner + if (size_node is not None) and ( + not isinstance(size_node.op, (Shape, Shape_i, PyTorchShapeTuple)) + ): + raise NotImplementedError(SIZE_NOT_COMPATIBLE) + +@pytorch_typify.register(RandomState) +def pytorch_typify_RandomState(state, **kwargs): + state = state.get_state(legacy=False) + state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] + # XXX: Is this a reasonable approach? + state["pytorch_state"] = state["state"]["key"][0:2] + return state + +@pytorch_typify.register(Generator) +def pytorch_typify_Generator(rng, **kwargs): + state = rng.__getstate__() + state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] + + # XXX: Is this a reasonable approach? + state["pytorch_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2] + + # The "state" and "inc" values in a NumPy `Generator` are 128 bits, which + # PyTorch can't handle, so we split these values into arrays of 32 bit integers + # and then combine the first two into a single 64 bit integers. + # + # XXX: Depending on how we expect these values to be used, is this approach + # reasonable? + # + # TODO: We might as well remove these altogether, since this conversion + # should only occur once (e.g. when the graph is converted/PyTorch-compiled), + # and, from then on, we use the custom "pytorch_state" value. + inc_32 = _coerce_to_uint32_array(state["state"]["inc"]) + state_32 = _coerce_to_uint32_array(state["state"]["state"]) + state["state"]["inc"] = inc_32[0] << 32 | inc_32[1] + state["state"]["state"] = state_32[0] << 32 | state_32[1] + return state + +@pytorch_funcify.register(aer.RandomVariable) +def pytorch_funcify_RandomVariable(op, node, **kwargs): + """PyTorch implementation of random variables.""" + rv = node.outputs[1] + out_dtype = rv.type.dtype + out_size = rv.type.shape + + if op.ndim_supp > 0: + out_size = node.outputs[1].type.shape[: -op.ndim_supp] + + # If one dimension has unknown size, either the size is determined + # by a `Shape` operator in which case PyTorch will compile, or it is + # not and we fail gracefully. + if None in out_size: + assert_size_argument_pytorch_compatible(node) + + def sample_fn(rng, size, dtype, *parameters): + return pytorch_sample_fn(op)(rng, size, out_dtype, *parameters) + + else: + + def sample_fn(rng, size, dtype, *parameters): + return pytorch_sample_fn(op)(rng, out_size, out_dtype, *parameters) + + return sample_fn + +@singledispatch +def pytorch_sample_fn(op): + name = op.name + raise NotImplementedError( + f"No PyTorch implementation for the given distribution: {name}" + ) + +@pytorch_sample_fn.register(aer.BetaRV) +@pytorch_sample_fn.register(aer.DirichletRV) +@pytorch_sample_fn.register(aer.PoissonRV) +@pytorch_sample_fn.register(aer.MvNormalRV) +def pytorch_sample_fn_generic(op): + """Generic PyTorch implementation of random variables.""" + name = op.name + pytorch_op = getattr(torch.distributions, name) + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + sample = pytorch_op(*parameters).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.CauchyRV) +@pytorch_sample_fn.register(aer.GumbelRV) +@pytorch_sample_fn.register(aer.LaplaceRV) +@pytorch_sample_fn.register(aer.LogisticRV) +@pytorch_sample_fn.register(aer.NormalRV) +@pytorch_sample_fn.register(aer.StandardNormalRV) +def pytorch_sample_fn_loc_scale(op): + """PyTorch implementation of random variables in the loc-scale families. 
+ + PyTorch only implements the standard version of random variables in the + loc-scale family. We thus need to translate and rescale the results + manually. + + """ + name = op.name + pytorch_op = getattr(torch.distributions, name) + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + loc, scale = parameters + sample = loc + pytorch_op(*parameters).sample(size) * scale + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.BernoulliRV) +@pytorch_sample_fn.register(aer.CategoricalRV) +def pytorch_sample_fn_no_dtype(op): + """Generic PyTorch implementation of random variables.""" + name = op.name + pytorch_op = getattr(torch.distributions, name) + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + sample = pytorch_op(*parameters).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.RandIntRV) +@pytorch_sample_fn.register(aer.IntegersRV) +@pytorch_sample_fn.register(aer.UniformRV) +def pytorch_sample_fn_uniform(op): + """PyTorch implementation of random variables with uniform density. + + We need to pass the arguments as keyword arguments since the order + of arguments is not the same. + + """ + name = op.name + # IntegersRV is equivalent to RandintRV + if isinstance(op, aer.IntegersRV): + name = "randint" + pytorch_op = getattr(torch.distributions, name) + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + minval, maxval = parameters + sample = pytorch_op(*parameters).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.ParetoRV) +@pytorch_sample_fn.register(aer.GammaRV) +def pytorch_sample_fn_shape_rate(op): + """PyTorch implementation of random variables in the shape-rate family. + + PyTorch only implements the standard version of random variables in the + shape-rate family. We thus need to rescale the results manually. 
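+
+    Note: the ``getattr(torch.distributions, name)`` lookups in this module use the
+    lowercase PyTensor op names (e.g. ``"gamma"``) while ``torch.distributions``
+    exposes capitalized classes (e.g. ``Gamma``), so a name mapping is likely needed.
+    ``torch.distributions.Gamma`` also accepts ``rate`` directly, so the extra
+    division by ``rate`` below may be redundant.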
+ + """ + name = op.name + pytorch_op = getattr(torch.distributions, name) + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + (shape, rate) = parameters + sample = pytorch_op(*parameters).sample(size) / rate + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.ExponentialRV) +def pytorch_sample_fn_exponential(op): + """PyTorch implementation of `ExponentialRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + (scale,) = parameters + sample = torch.distributions.Exponential(scale).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.StudentTRV) +def pytorch_sample_fn_t(op): + """PyTorch implementation of `StudentTRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + ( + df, + loc, + scale, + ) = parameters + sample = loc + torch.distributions.StudentT(df).sample(size) * scale + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.ChoiceRV) +def pytorch_funcify_choice(op): + """PyTorch implementation of `ChoiceRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + (a, p, replace) = parameters + smpl_value = torch.multinomial(p, size, replacement=replace) + rng["pytorch_state"] = rng_key + return (rng, smpl_value) + + return sample_fn + +@pytorch_sample_fn.register(aer.PermutationRV) +def pytorch_sample_fn_permutation(op): + """PyTorch implementation of `PermutationRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["pytorch_state"] + (x,) = parameters + sample = torch.randperm(x) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.BinomialRV) +def pytorch_sample_fn_binomial(op): + def sample_fn(rng, size, dtype, n, p): + rng_key = rng["pytorch_state"] + sample = torch.distributions.Binomial(n, p).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.MultinomialRV) +def pytorch_sample_fn_multinomial(op): + def sample_fn(rng, size, dtype, n, p): + rng_key = rng["pytorch_state"] + sample = torch.distributions.Multinomial(n, p).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn + +@pytorch_sample_fn.register(aer.VonMisesRV) +def pytorch_sample_fn_vonmises(op): + def sample_fn(rng, size, dtype, mu, kappa): + rng_key = rng["pytorch_state"] + sample = torch.distributions.VonMises(mu, kappa).sample(size) + rng["pytorch_state"] = rng_key + return (rng, sample) + + return sample_fn \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py new file mode 100644 index 0000000000..d65cc72da6 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -0,0 +1,282 @@ +import functools +import typing +from typing import Callable, Optional + +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.scalar import Softplus +from pytensor.scalar.basic import ( + Add, + Cast, + Clip, + Composite, + Identity, + IntDiv, + Mod, + Mul, + ScalarOp, + Second, + Sub, +) +from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi + + +def try_import_torch_op(op: ScalarOp, torch_op_name: Optional[str] = None) -> Callable: + try: + import torch + except 
ModuleNotFoundError: + raise NotImplementedError( + f"No PyTorch implementation for Op {op.name}. " + "Implementation is available if PyTorch is installed" + ) + + if torch_op_name is None: + torch_op_name = op.name + return typing.cast(Callable, getattr(torch, torch_op_name)) + + +def check_if_inputs_scalars(node): + """Check whether all the inputs of an `Elemwise` are scalar values. + + `torch` functions systematically return `Tensors`, + while the corresponding Python operators return concrete values when passed + concrete values. In order to be able to compile the largest number of graphs + possible we need to preserve concrete values whenever we can. We thus need + to dispatch differently the PyTensor operators depending on whether the inputs + are scalars. + + """ + ndims_input = [inp.type.ndim for inp in node.inputs] + are_inputs_scalars = True + for ndim in ndims_input: + try: + if ndim > 0: + are_inputs_scalars = False + except TypeError: + are_inputs_scalars = False + + return are_inputs_scalars + + +@pytorch_funcify.register(ScalarOp) +def pytorch_funcify_ScalarOp(op, node, **kwargs): + func_name = op.nfunc_spec[0] + + # We dispatch some PyTensor operators to Python operators + # whenever the inputs are all scalars. + are_inputs_scalars = check_if_inputs_scalars(node) + if are_inputs_scalars: + elemwise = elemwise_scalar(op) + if elemwise is not None: + return elemwise + + if "." in func_name: + torch_func = functools.reduce(getattr, [torch] + func_name.split(".")) + else: + torch_func = getattr(torch, func_name) + + if hasattr(op, "nfunc_variadic"): + # These are special cases that handle invalid arities due to the broken + # PyTensor `Op` type contract (e.g. binary `Op`s that also function as + # their own variadic counterparts--even when those counterparts already + # exist as independent `Op`s). + torch_variadic_func = getattr(torch, op.nfunc_variadic) + + def elemwise(*args): + if len(args) > op.nfunc_spec[1]: + return torch_variadic_func( + torch.stack(torch.broadcast_tensors(*args), dim=0), dim=0 + ) + else: + return torch_func(*args) + + return elemwise + else: + return torch_func + + +@functools.singledispatch +def elemwise_scalar(op): + return None + + +@elemwise_scalar.register(Add) +def elemwise_scalar_add(op): + def elemwise(*inputs): + return sum(inputs) + + return elemwise + + +@elemwise_scalar.register(Mul) +def elemwise_scalar_mul(op): + import operator + from functools import reduce + + def elemwise(*inputs): + return reduce(operator.mul, inputs, 1) + + return elemwise + + +@elemwise_scalar.register(Sub) +def elemwise_scalar_sub(op): + def elemwise(x, y): + return x - y + + return elemwise + + +@elemwise_scalar.register(IntDiv) +def elemwise_scalar_intdiv(op): + def elemwise(x, y): + return x // y + + return elemwise + + +@elemwise_scalar.register(Mod) +def elemwise_scalar_mod(op): + def elemwise(x, y): + return x % y + + return elemwise + + +@pytorch_funcify.register(Cast) +def pytorch_funcify_Cast(op, **kwargs): + def cast(x): + return torch.tensor(x).type(op.o_type.dtype) + + return cast + + +@pytorch_funcify.register(Identity) +def pytorch_funcify_Identity(op, **kwargs): + def identity(x): + return x + + return identity + + +@pytorch_funcify.register(Clip) +def pytorch_funcify_Clip(op, **kwargs): + """Register the translation for the `Clip` `Op`. + + PyTensor's `Clip` operator operates differently from PyTorch's when the + specified `min` is larger than the `max` so we cannot reuse `torch.clip` + to maintain consistency with PyTensor. 
+ + """ + + def clip(x, min, max): + return torch.where(x < min, min, torch.where(x > max, max, x)) + + return clip + + +@pytorch_funcify.register(Composite) +def pytorch_funcify_Composite(op, node, vectorize=True, **kwargs): + pytorch_impl = pytorch_funcify(op.fgraph) + + if len(node.outputs) == 1: + + def composite(*args): + return pytorch_impl(*args)[0] + + else: + + def composite(*args): + return pytorch_impl(*args) + + return torch.vectorize(composite) + + +@pytorch_funcify.register(Second) +def pytorch_funcify_Second(op, **kwargs): + def second(x, y): + _, y = torch.broadcast_tensors(x, y) + return y + + return second + + +@pytorch_funcify.register(Erf) +def pytorch_funcify_Erf(op, node, **kwargs): + def erf(x): + return torch.erf(x) + + return erf + + +@pytorch_funcify.register(Erfc) +def pytorch_funcify_Erfc(op, **kwargs): + def erfc(x): + return torch.erfc(x) + + return erfc + + +@pytorch_funcify.register(Erfinv) +def pytorch_funcify_Erfinv(op, **kwargs): + def erfinv(x): + return torch.erfinv(x) + + return erfinv + + +@pytorch_funcify.register(Erfcx) +@pytorch_funcify.register(Erfcinv) +def pytorch_funcify_from_tfp(op, **kwargs): + torch_op = try_import_torch_op(op) + + return torch_op + + +@pytorch_funcify.register(Iv) +def pytorch_funcify_Iv(op, **kwargs): + ive = try_import_torch_op(op, torch_op_name="bessel_ive") + + def iv(v, x): + return ive(v, x) / torch.exp(-torch.abs(torch.real(x))) + + return iv + + +@pytorch_funcify.register(Log1mexp) +def pytorch_funcify_Log1mexp(op, node, **kwargs): + def log1mexp(x): + return torch.where( + x < torch.log(0.5), torch.log1p(-torch.exp(x)), torch.log(-torch.expm1(x)) + ) + + return log1mexp + + +@pytorch_funcify.register(Psi) +def pytorch_funcify_Psi(op, node, **kwargs): + def psi(x): + return torch.digamma(x) + + return psi + + +@pytorch_funcify.register(Softplus) +def pytorch_funcify_Softplus(op, **kwargs): + def softplus(x): + return torch.where( + x < -37.0, + torch.exp(x), + torch.where( + x < 18.0, + torch.log1p(torch.exp(x)), + torch.where( + x < 33.3, + x + torch.exp(-x), + x, + ), + ), + ) + + return softplus \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/scan.py b/pytensor/link/pytorch/dispatch/scan.py new file mode 100644 index 0000000000..9730a190ab --- /dev/null +++ b/pytensor/link/pytorch/dispatch/scan.py @@ -0,0 +1,194 @@ +import torch + +from pytensor.compile.mode import PyTorch +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.scan.op import Scan + + +@pytorch_funcify.register(Scan) +def pytorch_funcify_Scan(op: Scan, **kwargs): + info = op.info + + if info.as_while: + raise NotImplementedError("While Scan cannot yet be converted to PyTorch") + + if info.n_mit_mot: + raise NotImplementedError( + "Scan with MIT-MOT (gradients of scan) cannot yet be converted to PyTorch" + ) + + # Optimize inner graph (exclude any defalut rewrites that are incompatible with PyTorch mode) + rewriter = op.mode_instance.excluding(*PyTorch._optimizer.exclude).optimizer + rewriter(op.fgraph) + scan_inner_func = pytorch_funcify(op.fgraph, **kwargs) + + def scan(*outer_inputs): + # Extract PyTorch scan inputs + outer_inputs = list(outer_inputs) + n_steps = outer_inputs[0] # PyTorch `length` + seqs = op.outer_seqs(outer_inputs) # PyTorch `xs` + + mit_sot_init = [] + for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)): + init_slice = seq[: abs(min(tap))] + mit_sot_init.append(init_slice) + + sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)] + + 
init_carry = ( + mit_sot_init, + sit_sot_init, + op.outer_shared(outer_inputs), + op.outer_non_seqs(outer_inputs), + ) # PyTorch `init` + + def pytorch_args_to_inner_func_args(carry, x): + """Convert PyTorch scan arguments into format expected by scan_inner_func. + + scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs) + """ + + # `carry` contains all inner taps, shared terms, and non_seqs + ( + inner_mit_sot, + inner_sit_sot, + inner_shared, + inner_non_seqs, + ) = carry + + # `x` contains the inner sequences + inner_seqs = x + + mit_sot_flatten = [] + for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices): + mit_sot_flatten.extend(array[torch.tensor(index)]) + + inner_scan_inputs = [ + *inner_seqs, + *mit_sot_flatten, + *inner_sit_sot, + *inner_shared, + *inner_non_seqs, + ] + + return inner_scan_inputs + + def inner_func_outs_to_pytorch_outs( + old_carry, + inner_scan_outs, + ): + """Convert inner_scan_func outputs into format expected by PyTorch scan. + + old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys) + """ + ( + inner_mit_sot, + inner_sit_sot, + inner_shared, + inner_non_seqs, + ) = old_carry + + inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs) + inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs) + inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs) + inner_shared_outs = op.inner_shared_outs(inner_scan_outs) + + # Replace the oldest mit_sot tap by the newest value + inner_mit_sot_new = [ + torch.cat([old_mit_sot[1:], new_val[None, ...]], dim=0) + for old_mit_sot, new_val in zip( + inner_mit_sot, + inner_mit_sot_outs, + ) + ] + + # Nothing needs to be done with sit_sot + inner_sit_sot_new = inner_sit_sot_outs + + inner_shared_new = inner_shared + # Replace old shared inputs by new shared outputs + inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs + + new_carry = ( + inner_mit_sot_new, + inner_sit_sot_new, + inner_shared_new, + inner_non_seqs, + ) + + # Shared variables and non_seqs are not traced + traced_outs = [ + *inner_mit_sot_outs, + *inner_sit_sot_outs, + *inner_nit_sot_outs, + ] + + return new_carry, traced_outs + + def pytorch_inner_func(carry, x): + inner_args = pytorch_args_to_inner_func_args(carry, x) + inner_scan_outs = list(scan_inner_func(*inner_args)) + new_carry, traced_outs = inner_func_outs_to_pytorch_outs(carry, inner_scan_outs) + return new_carry, traced_outs + + # Extract PyTensor scan outputs + final_carry, traces = torch.scan( + pytorch_inner_func, init_carry, seqs, length=n_steps + ) + + def get_partial_traces(traces): + """Convert PyTorch scan traces to PyTensor traces. + + We need to: + 1. Prepend initial states to PyTorch output traces + 2. 
Slice final traces if Scan was instructed to only keep a portion + """ + + init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot + buffers = ( + op.outer_mitsot(outer_inputs) + + op.outer_sitsot(outer_inputs) + + op.outer_nitsot(outer_inputs) + ) + partial_traces = [] + for init_state, trace, buffer in zip(init_states, traces, buffers): + if init_state is not None: + # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer + trace = torch.atleast_1d(trace) + init_state = torch.unsqueeze( + init_state, list(range(trace.ndim - init_state.ndim)) + ) + full_trace = torch.cat([init_state, trace], dim=0) + buffer_size = buffer.shape[0] + else: + # NIT-SOT: Buffer is just the number of entries that should be returned + full_trace = torch.atleast_1d(trace) + buffer_size = buffer + + partial_trace = full_trace[-buffer_size:] + partial_traces.append(partial_trace) + + return partial_traces + + def get_shared_outs(final_carry): + """Retrive last state of shared_outs from final_carry. + + These outputs cannot be traced in PyTensor Scan + """ + ( + inner_out_mit_sot, + inner_out_sit_sot, + inner_out_shared, + inner_in_non_seqs, + ) = final_carry + + shared_outs = inner_out_shared[: info.n_shared_outs] + return list(shared_outs) + + scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry) + + if len(scan_outs_final) == 1: + scan_outs_final = scan_outs_final[0] + return scan_outs_final + + return scan \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py new file mode 100644 index 0000000000..a912e89aad --- /dev/null +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -0,0 +1,114 @@ +import torch + +from pytensor.graph import Constant +from pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast +from pytensor.tensor.type import TensorType + + +class PyTorchShapeTuple(Op): + """Dummy Op that represents a `size` specified as a tuple.""" + + def make_node(self, *inputs): + dtype = inputs[0].type.dtype + otype = TensorType(dtype, shape=(len(inputs),)) + return Apply(self, inputs, [otype()]) + + def perform(self, *inputs): + return tuple(inputs) + + +@pytorch_funcify.register(PyTorchShapeTuple) +def pytorch_funcify_PyTorchShapeTuple(op, **kwargs): + def shape_tuple_fn(*x): + return tuple(x) + + return shape_tuple_fn + + +SHAPE_NOT_COMPATIBLE = """PyTorch requires concrete values for the `shape` parameter of `torch.reshape`. +Concrete values are either constants: + +>>> import pytensor.tensor as at +>>> x = at.ones(6) +>>> y = x.reshape((2, 3)) + +Or the shape of an array: + +>>> mat = at.matrix('mat') +>>> y = x.reshape(mat.shape) +""" + + +def assert_shape_argument_pytorch_compatible(shape): + """Assert whether the current node can be JIT-compiled by PyTorch. + + PyTorch can JIT-compile functions with a `shape` or `size` argument if it is + given a concrete value, i.e. either a constant or the shape of any traced + value. 
+ + """ + shape_op = shape.owner.op + if not isinstance(shape_op, (Shape, Shape_i, PyTorchShapeTuple)): + raise NotImplementedError(SHAPE_NOT_COMPATIBLE) + + +@pytorch_funcify.register(Reshape) +def pytorch_funcify_Reshape(op, node, **kwargs): + shape = node.inputs[1] + + if isinstance(shape, Constant): + constant_shape = shape.data + + def reshape(x, shape): + return torch.reshape(x, constant_shape) + + else: + assert_shape_argument_pytorch_compatible(shape) + + def reshape(x, shape): + return torch.reshape(x, shape) + + return reshape + + +@pytorch_funcify.register(Shape) +def pytorch_funcify_Shape(op, **kwargs): + def shape(x): + return torch.shape(x) + + return shape + + +@pytorch_funcify.register(Shape_i) +def pytorch_funcify_Shape_i(op, **kwargs): + i = op.i + + def shape_i(x): + return torch.shape(x)[i] + + return shape_i + + +@pytorch_funcify.register(SpecifyShape) +def pytorch_funcify_SpecifyShape(op, node, **kwargs): + def specifyshape(x, *shape): + assert x.ndim == len(shape) + for actual, expected in zip(x.shape, shape): + if expected is None: + continue + if actual != expected: + raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") + return x + + return specifyshape + + +@pytorch_funcify.register(Unbroadcast) +def pytorch_funcify_Unbroadcast(op, **kwargs): + def unbroadcast(x): + return x + + return unbroadcast \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py new file mode 100644 index 0000000000..2d783097b9 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/slinalg.py @@ -0,0 +1,43 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import torch_funcify +from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular + + +@torch_funcify.register(Cholesky) +def torch_funcify_Cholesky(op, **kwargs): + lower = op.lower + + def cholesky(a, lower=lower): + return torch.cholesky(a, upper=not lower).to(a.dtype) + + return cholesky + + +@torch_funcify.register(Solve) +def torch_funcify_Solve(op, **kwargs): + if op.assume_a != "gen" and op.lower: + lower = True + else: + lower = False + + def solve(a, b, lower=lower): + if lower: + return torch.triangular_solve(b, a, upper=False)[0] + else: + return torch.solve(b, a)[0] + + return solve + + +@torch_funcify.register(SolveTriangular) +def torch_funcify_SolveTriangular(op, **kwargs): + lower = op.lower + trans = op.trans + unit_diagonal = op.unit_diagonal + check_finite = op.check_finite + + def solve_triangular(A, b): + return torch.triangular_solve(b, A, upper=not lower, transpose=trans)[0] + + return solve_triangular \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/sparse.py b/pytensor/link/pytorch/dispatch/sparse.py new file mode 100644 index 0000000000..ec00147028 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/sparse.py @@ -0,0 +1,37 @@ +import torch +from scipy.sparse import spmatrix + +from pytensor.graph.basic import Constant +from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify +from pytensor.sparse.basic import Dot, StructuredDot +from pytensor.sparse.type import SparseTensorType + + +@pytorch_typify.register(spmatrix) +def pytorch_typify_spmatrix(matrix, dtype=None, **kwargs): + # Note: This changes the type of the constants from CSR/CSC to COO + # We could add COO as a PyTensor type but this would only be useful for PyTorch graphs + # and it would break the premise of one graph -> multiple backends. + # The same situation happens with RandomGenerators... 
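+    # NOTE: ``torch.sparse_coo_tensor`` expects a ``(2, nnz)`` COO index array, but for
+    # CSR/CSC inputs ``matrix.indices`` only holds the minor-axis indices, so converting
+    # first (e.g. ``coo = matrix.tocoo()`` and stacking ``coo.row``/``coo.col``) is
+    # probably needed here.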
+ return torch.sparse_coo_tensor(matrix.indices, matrix.data, matrix.shape) + + +@pytorch_funcify.register(Dot) +@pytorch_funcify.register(StructuredDot) +def pytorch_funcify_sparse_dot(op, node, **kwargs): + for input in node.inputs: + if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant): + raise NotImplementedError( + "PyTorch sparse dot only implemented for constant sparse inputs" + ) + + if isinstance(node.outputs[0].type, SparseTensorType): + raise NotImplementedError("PyTorch sparse dot only implemented for dense outputs") + + def sparse_dot(x, y): + out = torch.sparse.mm(x, y) + if out.is_sparse: + out = out.to_dense() + return out + + return sparse_dot \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py new file mode 100644 index 0000000000..1ec8384283 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/subtensor.py @@ -0,0 +1,118 @@ +import torch +from pytensor.link.pytorch.dispatch.basic import torch_funcify +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + indices_from_subtensor, +) +from pytensor.tensor.type_other import MakeSlice + +BOOLEAN_MASK_ERROR = """PyTorch does not support resizing arrays with boolean +masks. In some cases, however, it is possible to re-express your model +in a form that PyTorch can compile: + +>>> import pytensor.tensor as at +>>> x_at = at.vector('x') +>>> y_at = x_at[x_at > 0].sum() + +can be re-expressed as: + +>>> import pytensor.tensor as at +>>> x_at = at.vector('x') +>>> y_at = at.where(x_at > 0, x_at, 0).sum() +""" + +DYNAMIC_SLICE_LENGTH_ERROR = """PyTorch does not support slicing arrays with a dynamic +slice length. 
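+For example, the length of the slice in
+
+>>> import pytensor.tensor as at
+>>> x_at = at.vector('x')
+>>> d = at.iscalar('d')
+>>> y_at = x_at[:d]
+
+is only known at runtime, so the graph cannot be compiled.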
+""" + + +def subtensor_assert_indices_torch_compatible(node, idx_list): + from pytensor.graph.basic import Constant + from pytensor.tensor.variable import TensorVariable + + ilist = indices_from_subtensor(node.inputs[1:], idx_list) + for idx in ilist: + if isinstance(idx, TensorVariable): + if idx.type.dtype == "bool": + raise NotImplementedError(BOOLEAN_MASK_ERROR) + elif isinstance(idx, slice): + for slice_arg in (idx.start, idx.stop, idx.step): + if slice_arg is not None and not isinstance(slice_arg, Constant): + raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR) + + +@torch_funcify.register(Subtensor) +@torch_funcify.register(AdvancedSubtensor) +@torch_funcify.register(AdvancedSubtensor1) +def torch_funcify_Subtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + subtensor_assert_indices_torch_compatible(node, idx_list) + + def subtensor_constant(x, *ilists): + indices = indices_from_subtensor(ilists, idx_list) + if len(indices) == 1: + indices = indices[0] + + return x.__getitem__(indices) + + return subtensor_constant + + +@torch_funcify.register(IncSubtensor) +@torch_funcify.register(AdvancedIncSubtensor1) +def torch_funcify_IncSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + if getattr(op, "set_instead_of_inc", False): + + def torch_fn(x, indices, y): + x[indices] = y + return x + + else: + + def torch_fn(x, indices, y): + x[indices] += y + return x + + def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + + return torch_fn(x, indices, y) + + return incsubtensor + + +@torch_funcify.register(AdvancedIncSubtensor) +def torch_funcify_AdvancedIncSubtensor(op, node, **kwargs): + if getattr(op, "set_instead_of_inc", False): + + def torch_fn(x, indices, y): + x[indices] = y + return x + + else: + + def torch_fn(x, indices, y): + x[indices] += y + return x + + def advancedincsubtensor(x, y, *ilist, torch_fn=torch_fn): + return torch_fn(x, ilist, y) + + return advancedincsubtensor + + +@torch_funcify.register(MakeSlice) +def torch_funcify_MakeSlice(op, **kwargs): + def makeslice(*x): + return slice(*x) + + return makeslice \ No newline at end of file diff --git a/pytensor/link/pytorch/dispatch/tensor_basic.py b/pytensor/link/pytorch/dispatch/tensor_basic.py new file mode 100644 index 0000000000..8f8c0b9ef0 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/tensor_basic.py @@ -0,0 +1,207 @@ +import warnings + +import torch +import numpy as np + +from pytensor.graph.basic import Constant +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import ( + Alloc, + AllocEmpty, + ARange, + ExtractDiag, + Eye, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + Tri, + get_underlying_scalar_constant_value, +) +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.shape import Shape_i + + +ARANGE_CONCRETE_VALUE_ERROR = """PyTorch requires the arguments of `torch.arange` +to be constants. The graph that you defined thus cannot be JIT-compiled +by PyTorch. 
An example of a graph that can be compiled to PyTorch: +>>> import pytensor.tensor basic +>>> at.arange(1, 10, 2) +""" + + +@pytorch_funcify.register(AllocEmpty) +def pytorch_funcify_AllocEmpty(op, **kwargs): + def allocempty(*shape): + return torch.empty(shape, dtype=op.dtype) + + return allocempty + + +@pytorch_funcify.register(Alloc) +def pytorch_funcify_Alloc(op, node, **kwargs): + def alloc(x, *shape): + res = torch.broadcast_to(x, shape) + Alloc._check_runtime_broadcast(node, torch.as_tensor(x), res.shape) + return res + + return alloc + + +@pytorch_funcify.register(ARange) +def pytorch_funcify_ARange(op, node, **kwargs): + """Register a PyTorch implementation for `ARange`. + + `torch.arange` requires concrete values for its arguments. Here we check + that the arguments are constant, and raise otherwise. + + TODO: Handle other situations in which values are concrete (shape of an array). + + """ + arange_args = node.inputs + constant_args = [] + for arg in arange_args: + if arg.owner and isinstance(arg.owner.op, Shape_i): + constant_args.append(None) + elif isinstance(arg, Constant): + constant_args.append(arg.value) + else: + # TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)! + raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR) + + constant_start, constant_stop, constant_step = constant_args + + def arange(start, stop, step): + start = start if constant_start is None else constant_start + stop = stop if constant_stop is None else constant_stop + step = step if constant_step is None else constant_step + return torch.arange(start, stop, step, dtype=op.dtype) + + return arange + + +@pytorch_funcify.register(Join) +def pytorch_funcify_Join(op, **kwargs): + def join(axis, *tensors): + # tensors could also be tuples, and in this case they don't have a ndim + tensors = [torch.as_tensor(tensor) for tensor in tensors] + view = op.view + if (view != -1) and all( + tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :] + ): + return tensors[view] + + else: + return torch.cat(tensors, dim=axis) + + return join + + +@pytorch_funcify.register(Split) +def pytorch_funcify_Split(op: Split, node, **kwargs): + _, axis, splits = node.inputs + try: + constant_axis = get_underlying_scalar_constant_value(axis) + except NotScalarConstantError: + constant_axis = None + warnings.warn( + "Split node does not have constant axis. PyTorch implementation will likely fail" + ) + + try: + constant_splits = np.array( + [ + get_underlying_scalar_constant_value(splits[i]) + for i in range(get_vector_length(splits)) + ] + ) + except (ValueError, NotScalarConstantError): + constant_splits = None + warnings.warn( + "Split node does not have constant split positions. 
PyTorch implementation will likely fail" + ) + + def split(x, axis, splits): + if constant_axis is not None: + axis = constant_axis + if constant_splits is not None: + splits = constant_splits + cumsum_splits = np.cumsum(splits[:-1]) + else: + cumsum_splits = torch.cumsum(splits[:-1]) + + if len(splits) != op.len_splits: + raise ValueError("Length of splits is not equal to n_splits") + if np.sum(splits) != x.shape[axis]: + raise ValueError( + f"Split sizes do not sum up to input length along axis: {x.shape[axis]}" + ) + if np.any(splits < 0): + raise ValueError("Split sizes cannot be negative") + + return torch.split(x, cumsum_splits, axis=axis) + + return split + + +@pytorch_funcify.register(ExtractDiag) +def pytorch_funcify_ExtractDiag(op, **kwargs): + offset = op.offset + axis1 = op.axis1 + axis2 = op.axis2 + + def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): + return torch.diagonal(x, offset=offset, dim1=axis1, dim2=axis2) + + return extract_diag + + +@pytorch_funcify.register(Eye) +def pytorch_funcify_Eye(op, **kwargs): + dtype = op.dtype + + def eye(N, M, k): + return torch.eye(N, M, k, dtype=dtype) + + return eye + + +@pytorch_funcify.register(MakeVector) +def pytorch_funcify_MakeVector(op, **kwargs): + def makevector(*x): + return torch.tensor(x, dtype=op.dtype) + + return makevector + + +@pytorch_funcify.register(TensorFromScalar) +def pytorch_funcify_TensorFromScalar(op, **kwargs): + def tensor_from_scalar(x): + return x + + return tensor_from_scalar + + +@pytorch_funcify.register(ScalarFromTensor) +def pytorch_funcify_ScalarFromTensor(op, **kwargs): + def scalar_from_tensor(x): + return torch.tensor(x).flatten()[0] + + return scalar_from_tensor + + +@pytorch_funcify.register(Tri) +def pytorch_funcify_Tri(op, node, **kwargs): + # node.inputs is N, M, k + const_args = [getattr(x, "data", None) for x in node.inputs] + + def tri(*args): + # args is N, M, k + args = [ + x if const_x is None else const_x for x, const_x in zip(args, const_args) + ] + return torch.tri(*args, dtype=op.dtype) + + return tri \ No newline at end of file From bd2b93ebbd2343e33b6a522b00b27f0f5e38115a Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Mon, 25 Sep 2023 17:45:08 +0200 Subject: [PATCH 2/4] Add tests. 
--- tests/link/pytorch/test_basic.py | 211 +++++++ tests/link/pytorch/test_elemwise.py | 139 +++++ tests/link/pytorch/test_extra_ops.py | 102 ++++ tests/link/pytorch/test_nlinalg.py | 176 ++++++ tests/link/pytorch/test_random.py | 782 ++++++++++++++++++++++++ tests/link/pytorch/test_scalar.py | 280 +++++++++ tests/link/pytorch/test_scan.py | 429 +++++++++++++ tests/link/pytorch/test_shape.py | 88 +++ tests/link/pytorch/test_slinalg.py | 131 ++++ tests/link/pytorch/test_sparse.py | 75 +++ tests/link/pytorch/test_subtensor.py | 251 ++++++++ tests/link/pytorch/test_tensor_basic.py | 239 ++++++++ 12 files changed, 2903 insertions(+) create mode 100644 tests/link/pytorch/test_basic.py create mode 100644 tests/link/pytorch/test_elemwise.py create mode 100644 tests/link/pytorch/test_extra_ops.py create mode 100644 tests/link/pytorch/test_nlinalg.py create mode 100644 tests/link/pytorch/test_random.py create mode 100644 tests/link/pytorch/test_scalar.py create mode 100644 tests/link/pytorch/test_scan.py create mode 100644 tests/link/pytorch/test_shape.py create mode 100644 tests/link/pytorch/test_slinalg.py create mode 100644 tests/link/pytorch/test_sparse.py create mode 100644 tests/link/pytorch/test_subtensor.py create mode 100644 tests/link/pytorch/test_tensor_basic.py diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py new file mode 100644 index 0000000000..aa39bd2836 --- /dev/null +++ b/tests/link/pytorch/test_basic.py @@ -0,0 +1,211 @@ +from functools import partial +from typing import Callable, Iterable, Optional + +import numpy as np +import pytest + +from pytensor.compile.function import function +from pytensor.compile.mode import get_mode +from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.configdefaults import config +from pytensor.graph.basic import Apply +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import Op, get_test_value +from pytensor.ifelse import ifelse +from pytensor.raise_op import assert_op +from pytensor.tensor.type import dscalar, scalar, vector + + +@pytest.fixture(scope="module", autouse=True) +def set_pytensor_flags(): + with config.change_flags(cxx="", compute_test_value="ignore"): + yield + + +pytorch = pytest.importorskip("torch") + + +# We assume that the PyTorch mode includes all the rewrites needed to transpile PyTorch graphs +pytorch_mode = get_mode("PYTORCH") +py_mode = get_mode("FAST_COMPILE") + + +def compare_pytorch_and_py( + fgraph: FunctionGraph, + test_inputs: Iterable, + assert_fn: Optional[Callable] = None, + must_be_device_array: bool = True, + pytorch_mode=pytorch_mode, + py_mode=py_mode, +): + """Function to compare python graph output and pytorch compiled output for testing equality + + In the tests below computational graphs are defined in PyTensor. These graphs are then passed to + this function which then compiles the graphs in both pytorch and python, runs the calculation + in both and checks if the results are the same + + Parameters + ---------- + fgraph: FunctionGraph + PyTensor function Graph object + test_inputs: iter + Numerical inputs for testing the function graph + assert_fn: func, opt + Assert function used to check for equality between python and pytorch. If not + provided uses np.testing.assert_allclose + must_be_device_array: Bool + Checks for instance of pytorch.interpreters.xla.DeviceArray. 
For testing purposes + if this device array is found it indicates if the result was computed by pytorch + + Returns + ------- + pytorch_res + + """ + if assert_fn is None: + assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) + + fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] + pytensor_pytorch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode) + pytorch_res = pytensor_pytorch_fn(*test_inputs) + + if must_be_device_array: + if isinstance(pytorch_res, list): + assert all(isinstance(res, pytorch.Array) for res in pytorch_res) + else: + assert isinstance(pytorch_res, pytorch.interpreters.xla.DeviceArray) + + pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) + + if len(fgraph.outputs) > 1: + for j, p in zip(pytorch_res, py_res): + assert_fn(j, p) + else: + assert_fn(pytorch_res, py_res) + + return pytensor_pytorch_fn, pytorch_res + + +def test_pytorch_FunctionGraph_once(): + """Make sure that an output is only computed once when it's referenced multiple times.""" + from pytensor.link.pytorch.dispatch import pytorch_funcify + + x = vector("x") + y = vector("y") + + class TestOp(Op): + def __init__(self): + self.called = 0 + + def make_node(self, *args): + return Apply(self, list(args), [x.type() for x in args]) + + def perform(self, inputs, outputs): + for i, inp in enumerate(inputs): + outputs[i][0] = inp[0] + + @pytorch_funcify.register(TestOp) + def pytorch_funcify_TestOp(op, **kwargs): + def func(*args, op=op): + op.called += 1 + return list(args) + + return func + + op1 = TestOp() + op2 = TestOp() + + q, r = op1(x, y) + outs = op2(q + r, q + r) + + out_fg = FunctionGraph([x, y], outs, clone=False) + assert len(out_fg.outputs) == 2 + + out_jx = pytorch_funcify(out_fg) + + x_val = np.r_[1, 2].astype(config.floatX) + y_val = np.r_[2, 3].astype(config.floatX) + + res = out_jx(x_val, y_val) + assert len(res) == 2 + assert op1.called == 1 + assert op2.called == 1 + + res = out_jx(x_val, y_val) + assert len(res) == 2 + assert op1.called == 2 + assert op2.called == 2 + + +def test_shared(): + a = shared(np.array([1, 2, 3], dtype=config.floatX)) + + pytensor_pytorch_fn = function([], a, mode="PyTorch") + pytorch_res = pytensor_pytorch_fn() + + assert isinstance(pytorch_res, pytorch.Array) + np.testing.assert_allclose(pytorch_res, a.get_value()) + + pytensor_pytorch_fn = function([], a * 2, mode="PyTorch") + pytorch_res = pytensor_pytorch_fn() + + assert isinstance(pytorch_res, pytorch.Array) + np.testing.assert_allclose(pytorch_res, a.get_value() * 2) + + # Changed the shared value and make sure that the PyTorch-compiled + # function also changes. 
+ new_a_value = np.array([3, 4, 5], dtype=config.floatX) + a.set_value(new_a_value) + + pytorch_res = pytensor_pytorch_fn() + assert isinstance(pytorch_res, pytorch.Array) + np.testing.assert_allclose(pytorch_res, new_a_value * 2) + + +def test_shared_updates(): + a = shared(0) + + pytensor_pytorch_fn = function([], a, updates={a: a + 1}, mode="PyTorch") + res1, res2 = pytensor_pytorch_fn(), pytensor_pytorch_fn() + assert res1 == 0 + assert res2 == 1 + assert a.get_value() == 2 + + a.set_value(5) + res1, res2 = pytensor_pytorch_fn(), pytensor_pytorch_fn() + assert res1 == 5 + assert res2 == 6 + assert a.get_value() == 7 + + +def test_pytorch_ifelse(): + true_vals = np.r_[1, 2, 3] + false_vals = np.r_[-1, -2, -3] + + x = ifelse(np.array(True), true_vals, false_vals) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + a = dscalar("a") + a.tag.test_value = np.array(0.2, dtype=config.floatX) + x = ifelse(a < 0.5, true_vals, false_vals) + x_fg = FunctionGraph([a], [x]) # I.e. False + + compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) + + +def test_pytorch_checkandraise(): + p = scalar() + p.tag.test_value = 0 + + res = assert_op(p, p < 1.0) + + with pytest.warns(UserWarning): + function((p,), res, mode=pytorch_mode) + + +def set_test_value(x, v): + x.tag.test_value = v + return x diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py new file mode 100644 index 0000000000..a7816928fa --- /dev/null +++ b/tests/link/pytorch/test_elemwise.py @@ -0,0 +1,139 @@ +import numpy as np +import pytest +import scipy.special + +import pytensor +import pytensor.tensor as at +from pytensor.compile import get_mode +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.tensor import elemwise as at_elemwise +from pytensor.tensor.math import all as at_all +from pytensor.tensor.math import prod +from pytensor.tensor.math import sum as at_sum +from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax +from pytensor.tensor.type import matrix, tensor, vector, vectors +from tests.link.pytorch.test_basic import compare_pytorch_and_py +from tests.tensor.test_elemwise import TestElemwise + + +def test_elemwise_runtime_broadcast(): + TestElemwise.check_runtime_broadcast(get_mode("PyTorch")) + + +def test_pytorch_Dimshuffle(): + a_at = matrix("a") + + x = a_at.T + x_fg = FunctionGraph([a_at], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + x = a_at.dimshuffle([0, 1, "x"]) + x_fg = FunctionGraph([a_at], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + a_at = tensor(dtype=config.floatX, shape=(None, 1)) + x = a_at.dimshuffle((0,)) + x_fg = FunctionGraph([a_at], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + a_at = tensor(dtype=config.floatX, shape=(None, 1)) + x = at_elemwise.DimShuffle([False, True], (0,))(a_at) + x_fg = FunctionGraph([a_at], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + +def test_pytorch_CAReduce(): + a_at = vector("a") + a_at.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) + + x = at_sum(a_at, axis=None) + x_fg = FunctionGraph([a_at], [x]) + + compare_pytorch_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)]) + + a_at = matrix("a") + a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) + + x = at_sum(a_at, axis=0) + 
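The reductions exercised in this test all have direct torch counterparts; a hand-written torch reference for the same cases (an illustrative sketch, not the code emitted by the CAReduce dispatcher) is:

import torch

a = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # same layout as np.c_[[1, 2, 3], [1, 2, 3]]

torch.sum(a)          # axis=None: scalar sum over all elements
torch.sum(a, dim=0)   # column sums     -> tensor([6., 6.])
torch.sum(a, dim=1)   # row sums        -> tensor([2., 4., 6.])
torch.prod(a, dim=0)  # column products -> tensor([6., 6.])
torch.all(a != 0)     # logical-AND reduction, matching `at_all`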
x_fg = FunctionGraph([a_at], [x]) + + compare_pytorch_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + x = at_sum(a_at, axis=1) + x_fg = FunctionGraph([a_at], [x]) + + compare_pytorch_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + a_at = matrix("a") + a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) + + x = prod(a_at, axis=0) + x_fg = FunctionGraph([a_at], [x]) + + compare_pytorch_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + x = at_all(a_at) + x_fg = FunctionGraph([a_at], [x]) + + compare_pytorch_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_softmax(axis): + x = matrix("x") + x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = softmax(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_logsoftmax(axis): + x = matrix("x") + x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = log_softmax(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_softmax_grad(axis): + dy = matrix("dy") + dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + sm = matrix("sm") + sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = SoftmaxGrad(axis=axis)(dy, sm) + fgraph = FunctionGraph([dy, sm], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)]) +@pytest.mark.parametrize("axis", [0, 1]) +def test_logsumexp_benchmark(size, axis, benchmark): + X = at.matrix("X") + X_max = at.max(X, axis=axis, keepdims=True) + X_max = at.switch(at.isinf(X_max), 0, X_max) + X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max + + rng = np.random.default_rng(23920) + X_val = rng.normal(size=size) + + X_lse_fn = pytensor.function([X], X_lse, mode="PyTorch") + + # JIT compile first + _ = X_lse_fn(X_val) + + res = benchmark(X_lse_fn, X_val) + + exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) + np.testing.assert_array_almost_equal(res, exp_res) + + +def test_multiple_input_multiply(): + x, y, z = vectors("xyz") + out = at.mul(x, y, z) + + fg = FunctionGraph(outputs=[out], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5], [3.5]]) diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py new file mode 100644 index 0000000000..afa5a0aff5 --- /dev/null +++ b/tests/link/pytorch/test_extra_ops.py @@ -0,0 +1,102 @@ +import numpy as np +import pytest +from packaging.version import parse as version_parse + +import pytensor.tensor.basic as at +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.tensor import extra_ops as at_extra_ops +from pytensor.tensor.type import matrix +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +pytorch = pytest.importorskip("torch") + + +def set_test_value(x, v): + x.tag.test_value = v + return x + + +def test_extra_ops(): + a = matrix("a") + a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + + out = at_extra_ops.cumsum(a, axis=0) + fgraph = FunctionGraph([a], [out]) + 
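Most of the extra ops covered in this file map onto standard torch functions; the sketch below is an assumption about a reasonable lowering, not the dispatch actually implemented in this patch:

import torch

a = torch.arange(6, dtype=torch.float64).reshape(3, 2)

torch.cumsum(a, dim=0)                                    # cumsum(a, axis=0)
torch.cumprod(a, dim=1)                                   # cumprod(a, axis=1)
torch.diff(a, n=2, dim=1)                                 # diff(a, n=2, axis=1)
torch.repeat_interleave(a, torch.tensor([3, 3]), dim=1)   # repeat(a, (3, 3), axis=1)
a.clone().fill_diagonal_(5.0)                             # fill_diagonal(a, 5)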
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_extra_ops.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_extra_ops.diff(a, n=2, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_extra_ops.repeat(a, (3, 3), axis=1) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + c = at.as_tensor(5) + + out = at_extra_ops.fill_diagonal(a, c) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + with pytest.raises(NotImplementedError): + out = at_extra_ops.fill_diagonal_offset(a, c, c) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + with pytest.raises(NotImplementedError): + out = at_extra_ops.Unique(axis=1)(a) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + indices = np.arange(np.prod((3, 4))) + out = at_extra_ops.unravel_index(indices, (3, 4), order="C") + fgraph = FunctionGraph([], out) + compare_pytorch_and_py( + fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False + ) + + +@pytest.mark.xfail( + version_parse(pytorch.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_extra_ops_omni(): + a = matrix("a") + a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + + # This function also cannot take symbolic input. + c = at.as_tensor(5) + out = at_extra_ops.bartlett(c) + fgraph = FunctionGraph([], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4)) + out = at_extra_ops.ravel_multi_index(multi_index, (3, 4)) + fgraph = FunctionGraph([], [out]) + compare_pytorch_and_py( + fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False + ) + + # The inputs are "concrete", yet it still has problems? 
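(Editorial note, an assumption rather than something this patch implements: if `Unique` eventually gets a dispatch, `torch.unique` is the obvious candidate.)

import torch

x = torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
torch.unique(x)         # unique elements of the flattened tensor
torch.unique(x, dim=1)  # unique columns, roughly what Unique(axis=1) asks for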
+ out = at_extra_ops.Unique()( + at.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2))) + ) + fgraph = FunctionGraph([], [out]) + compare_pytorch_and_py(fgraph, []) + + +@pytest.mark.xfail(reason="pytorch.numpy.arange requires concrete inputs") +def test_unique_nonconcrete(): + a = matrix("a") + a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + + out = at_extra_ops.Unique()(a) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py new file mode 100644 index 0000000000..78ace9a049 --- /dev/null +++ b/tests/link/pytorch/test_nlinalg.py @@ -0,0 +1,176 @@ +import numpy as np +import pytest +from packaging.version import parse as version_parse + +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.link.pytorch import PyTorchLinker +from pytensor.tensor import blas as at_blas +from pytensor.tensor import nlinalg as at_nlinalg +from pytensor.tensor.math import MaxAndArgmax +from pytensor.tensor.math import max as at_max +from pytensor.tensor.math import maximum +from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +pytorch = pytest.importorskip("torch") + + +def test_pytorch_BatchedDot(): + # tensor3 . tensor3 + a = tensor3("a") + a.tag.test_value = ( + np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + ) + b = tensor3("b") + b.tag.test_value = ( + np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + ) + out = at_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # A dimension mismatch should raise a TypeError for compatibility + inputs = [get_test_value(a)[:-1], get_test_value(b)] + opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + pytorch_mode = Mode(PyTorchLinker(), opts) + pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode) + with pytest.raises(TypeError): + pytensor_pytorch_fn(*inputs) + + # matrix . 
matrix + a = matrix("a") + a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3)) + b = matrix("b") + b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3)) + out = at_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_pytorch_basic_multiout(): + rng = np.random.default_rng(213234) + + M = rng.normal(size=(3, 3)) + X = M.dot(M.T) + + x = matrix("x") + + outs = at_nlinalg.eig(x) + out_fg = FunctionGraph([x], outs) + + def assert_fn(x, y): + np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) + + compare_pytorch_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.eigh(x) + out_fg = FunctionGraph([x], outs) + compare_pytorch_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.qr(x, mode="full") + out_fg = FunctionGraph([x], outs) + compare_pytorch_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.qr(x, mode="reduced") + out_fg = FunctionGraph([x], outs) + compare_pytorch_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.svd(x) + out_fg = FunctionGraph([x], outs) + compare_pytorch_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.slogdet(x) + out_fg = FunctionGraph([x], outs) + compare_pytorch_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + +@pytest.mark.xfail( + version_parse(pytorch.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_pytorch_basic_multiout_omni(): + # Test that a single output of a multi-output `Op` can be used as input to + # another `Op` + x = dvector() + mx, amx = MaxAndArgmax([0])(x) + out = mx * amx + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [np.r_[1, 2]]) + + +@pytest.mark.xfail( + version_parse(pytorch.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_tensor_basics(): + y = vector("y") + y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + x = vector("x") + x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + A = matrix("A") + A.tag.test_value = np.empty((2, 2), dtype=config.floatX) + alpha = scalar("alpha") + alpha.tag.test_value = np.array(3.0, dtype=config.floatX) + beta = scalar("beta") + beta.tag.test_value = np.array(5.0, dtype=config.floatX) + + # This should be converted into a `Gemv` `Op` when the non-PyTorch compatible + # optimizations are turned on; however, when using PyTorch mode, it should + # leave the expression alone. 
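For orientation, the expression below is `alpha * (y @ A @ x) + beta * y`, the pattern the `Gemv`/BLAS rewrites normally pick apart; a quick NumPy evaluation with stand-in values (illustration only; the test itself leaves `A`'s test value uninitialised):

import numpy as np

y = np.array([1.0, 2.0])
x = np.array([3.0, 4.0])
A = np.eye(2)                 # stand-in value
alpha, beta = 3.0, 5.0

out = y.dot(alpha * A).dot(x) + beta * y   # == alpha * (y @ A @ x) + beta * y
print(out)                                 # [38. 43.] for these stand-in values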
+ out = y.dot(alpha * A).dot(x) + beta * y + fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = maximum(y, x) + fgraph = FunctionGraph([y, x], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_max(y) + fgraph = FunctionGraph([y], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_pinv(): + x = matrix("x") + x_inv = at_nlinalg.pinv(x) + + fgraph = FunctionGraph([x], [x_inv]) + x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + compare_pytorch_and_py(fgraph, [x_np]) + + +def test_pinv_hermitian(): + A = matrix("A", dtype="complex128") + A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]] + A_not_h_test = A_h_test + 0 + 1j + + A_inv = at_nlinalg.pinv(A, hermitian=False) + pytorch_fn = function([A], A_inv, mode="PyTorch") + + assert np.allclose(pytorch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False)) + assert np.allclose(pytorch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True)) + assert np.allclose( + pytorch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False) + ) + assert not np.allclose( + pytorch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) + ) + + A_inv = at_nlinalg.pinv(A, hermitian=True) + pytorch_fn = function([A], A_inv, mode="PyTorch") + + assert np.allclose(pytorch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False)) + assert np.allclose(pytorch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True)) + assert not np.allclose( + pytorch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False) + ) + # Numpy fails differently than PyTorch when hermitian assumption is violated + assert not np.allclose( + pytorch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) + ) diff --git a/tests/link/pytorch/test_random.py b/tests/link/pytorch/test_random.py new file mode 100644 index 0000000000..f7ec310227 --- /dev/null +++ b/tests/link/pytorch/test_random.py @@ -0,0 +1,782 @@ +import numpy as np +import pytest +import scipy.stats as stats + +import pytensor +import pytensor.tensor as at +import pytensor.tensor.random as aer +from pytensor.compile.function import function +from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.graph.basic import Constant +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.random.basic import RandomVariable +from pytensor.tensor.random.utils import RandomStream +from tests.link.pytorch.test_basic import compare_pytorch_and_py, pytorch_mode, set_test_value + + +pytorch = pytest.importorskip("torch") + + +from pytensor.link.pytorch.dispatch.random import numpyro_available # noqa: E402 + + +def random_function(*args, **kwargs): + with pytest.warns( + UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" + ): + return function(*args, **kwargs) + + +def test_random_RandomStream(): + """Two successive calls of a compiled graph using `RandomStream` should + return different values. 
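A minimal default-backend sketch of the same property (for reference only; the test runs the graph through the PyTorch mode instead):

import numpy as np
from pytensor import function
from pytensor.tensor.random.utils import RandomStream

srng = RandomStream(seed=123)
f = function([], srng.normal())   # the stream's RNG state is updated on every call
assert not np.array_equal(f(), f())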
+ + """ + srng = RandomStream(seed=123) + out = srng.normal() - srng.normal() + + fn = random_function([], out, mode=pytorch_mode) + pytorch_res_1 = fn() + pytorch_res_2 = fn() + + assert not np.array_equal(pytorch_res_1, pytorch_res_2) + + +@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng)) +def test_random_updates(rng_ctor): + original_value = rng_ctor(seed=98) + rng = shared(original_value, name="original_rng", borrow=False) + next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs + + f = random_function([], [x], updates={rng: next_rng}, mode=pytorch_mode) + assert f() != f() + + # Check that original rng variable content was not overwritten when calling pytorch_typify + assert all( + a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) + for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__()) + ) + + +def test_random_updates_input_storage_order(): + """Test case described in issue #314. + + This happened when we tried to update the input storage after we clone the shared RNG. + We used to call `input_storage.index(old_input_storage)` which would fail when the input_storage contained + numpy arrays before the RNG value, which would fail the equality check. + + """ + pt_rng = RandomStream(1) + + batchshape = (3, 1, 4, 4) + inp_shared = pytensor.shared( + np.zeros(batchshape, dtype="float64"), name="inp_shared" + ) + + inp = at.tensor4(dtype="float64", name="inp") + inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5) + + # This function replaces inp by input_shared in the update expression + # This is what caused the RNG to appear later than inp_shared in the input_storage + + fn = random_function( + inputs=[], + outputs=[], + updates={inp_shared: inp_update}, + givens={inp: inp_shared}, + mode="PyTorch", + ) + fn() + np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3) + fn() + np.testing.assert_allclose(inp_shared.get_value(), 10, rtol=1e-3) + + +@pytest.mark.parametrize( + "rv_op, dist_params, base_size, cdf_name, params_conv", + [ + ( + aer.beta, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "beta", + lambda *args: args, + ), + ( + aer.cauchy, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "cauchy", + lambda *args: args, + ), + ( + aer.exponential, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + ], + (2,), + "expon", + lambda *args: (0, args[0]), + ), + ( + aer.gamma, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "gamma", + lambda a, b: (a, 0.0, b), + ), + ( + aer.gumbel, + [ + set_test_value( + at.lvector(), + np.array([1, 2], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "gumbel_r", + lambda *args: args, + ), + ( + aer.laplace, + [ + set_test_value(at.dvector(), np.array([1.0, 2.0], dtype=np.float64)), + set_test_value(at.dscalar(), np.array(1.0, dtype=np.float64)), + ], + (2,), + "laplace", + lambda *args: args, + ), + ( + aer.logistic, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + 
(2,), + "logistic", + lambda *args: args, + ), + ( + aer.lognormal, + [ + set_test_value( + at.lvector(), + np.array([0, 0], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "lognorm", + lambda mu, sigma: (sigma, 0, np.exp(mu)), + ), + ( + aer.normal, + [ + set_test_value( + at.lvector(), + np.array([1, 2], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "norm", + lambda *args: args, + ), + ( + aer.pareto, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ) + ], + (2,), + "pareto", + lambda *args: args, + ), + ( + aer.poisson, + [ + set_test_value( + at.dvector(), + np.array([100000.0, 200000.0], dtype=np.float64), + ), + ], + (2,), + "poisson", + lambda *args: args, + ), + ( + aer.randint, + [ + set_test_value( + at.lscalar(), + np.array(0, dtype=np.int64), + ), + set_test_value( # high-value necessary since test on cdf + at.lscalar(), + np.array(1000, dtype=np.int64), + ), + ], + (), + "randint", + lambda *args: args, + ), + ( + aer.integers, + [ + set_test_value( + at.lscalar(), + np.array(0, dtype=np.int64), + ), + set_test_value( # high-value necessary since test on cdf + at.lscalar(), + np.array(1000, dtype=np.int64), + ), + ], + (), + "randint", + lambda *args: args, + ), + ( + aer.standard_normal, + [], + (2,), + "norm", + lambda *args: args, + ), + ( + aer.t, + [ + set_test_value( + at.dscalar(), + np.array(2.0, dtype=np.float64), + ), + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "t", + lambda *args: args, + ), + ( + aer.uniform, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1000.0, dtype=np.float64), + ), + ], + (2,), + "uniform", + lambda *args: args, + ), + ( + aer.halfnormal, + [ + set_test_value( + at.dvector(), + np.array([-1.0, 200.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1000.0, dtype=np.float64), + ), + ], + (2,), + "halfnorm", + lambda *args: args, + ), + ( + aer.invgamma, + [ + set_test_value( + at.dvector(), + np.array([10.4, 2.8], dtype=np.float64), + ), + set_test_value( + at.dvector(), + np.array([3.4, 7.3], dtype=np.float64), + ), + ], + (2,), + "invgamma", + lambda a, b: (a, 0, b), + ), + ( + aer.chisquare, + [ + set_test_value( + at.dvector(), + np.array([2.4, 4.9], dtype=np.float64), + ), + ], + (2,), + "chi2", + lambda *args: args, + ), + ( + aer.gengamma, + [ + set_test_value( + at.dvector(), + np.array([10.4, 2.8], dtype=np.float64), + ), + set_test_value( + at.dvector(), + np.array([3.4, 7.3], dtype=np.float64), + ), + set_test_value( + at.dvector(), + np.array([0.9, 2.0], dtype=np.float64), + ), + ], + (2,), + "gengamma", + lambda alpha, p, lambd: (alpha / p, p, 0, lambd), + ), + ( + aer.wald, + [ + set_test_value( + at.dvector(), + np.array([10.4, 2.8], dtype=np.float64), + ), + set_test_value( + at.dvector(), + np.array([4.5, 2.0], dtype=np.float64), + ), + ], + (2,), + "invgauss", + # https://stackoverflow.com/a/48603469 + lambda mean, scale: (mean / scale, 0, scale), + ), + pytest.param( + aer.vonmises, + [ + set_test_value( + at.dvector(), + np.array([-0.5, 1.3], dtype=np.float64), + ), + set_test_value( + at.dvector(), + np.array([5.5, 13.0], dtype=np.float64), + ), + ], + (2,), + "vonmises", + lambda mu, kappa: (kappa, mu), + marks=pytest.mark.skipif( + not 
numpyro_available, reason="VonMises dispatch requires numpyro" + ), + ), + ], +) +def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv): + """The PyTorch samplers are not one-to-one with NumPy samplers so we + need to use a statistical test to make sure that the transpilation + is correct. + + Parameters + ---------- + rv_op + The transpiled `RandomVariable` `Op`. + dist_params + The parameters passed to the op. + + """ + if rv_op is aer.integers: + # Integers only accepts Generator, not RandomState + rng = shared(np.random.default_rng(29402)) + else: + rng = shared(np.random.RandomState(29402)) + g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng) + g_fn = random_function(dist_params, g, mode=pytorch_mode) + samples = g_fn( + *[ + i.tag.test_value + for i in g_fn.maker.fgraph.inputs + if not isinstance(i, (SharedVariable, Constant)) + ] + ) + + bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_params]) + + for idx in np.ndindex(*base_size): + cdf_params = params_conv(*tuple(arg[idx] for arg in bcast_dist_args)) + test_res = stats.cramervonmises( + samples[(Ellipsis,) + idx], cdf_name, args=cdf_params + ) + assert not np.isnan(test_res.statistic) + assert test_res.pvalue > 0.01 + + +@pytest.mark.parametrize("size", [(), (4,)]) +def test_random_bernoulli(size): + rng = shared(np.random.RandomState(123)) + g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) + + +def test_random_mvnormal(): + rng = shared(np.random.RandomState(123)) + + mu = np.ones(4) + cov = np.eye(4) + g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) + + +@pytest.mark.parametrize( + "parameter, size", + [ + (np.ones(4), ()), + (np.ones(4), (2, 4)), + ], +) +def test_random_dirichlet(parameter, size): + rng = shared(np.random.RandomState(123)) + g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) + + +def test_random_choice(): + # Elements are picked at equal frequency + num_samples = 10000 + rng = shared(np.random.RandomState(123)) + g = at.random.choice(np.arange(4), size=num_samples, rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2) + + # `replace=False` produces unique results + rng = shared(np.random.RandomState(123)) + g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + assert len(np.unique(samples)) == 99 + + # We can pass an array with probabilities + rng = shared(np.random.RandomState(123)) + g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples, np.zeros(10)) + + +def test_random_categorical(): + rng = shared(np.random.RandomState(123)) + g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1) + + +def test_random_permutation(): + array = np.arange(4) + rng = 
shared(np.random.RandomState(123)) + g = at.random.permutation(array, rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + permuted = g_fn() + with pytest.raises(AssertionError): + np.testing.assert_allclose(array, permuted) + + +def test_random_geometric(): + rng = shared(np.random.RandomState(123)) + p = np.array([0.3, 0.7]) + g = at.random.geometric(p, size=(10_000, 2), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1) + np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1) + + +def test_negative_binomial(): + rng = shared(np.random.RandomState(123)) + n = np.array([10, 40]) + p = np.array([0.3, 0.7]) + g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1) + np.testing.assert_allclose( + samples.std(axis=0), np.sqrt(n * (1 - p) / p**2), rtol=0.1 + ) + + +@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro") +def test_binomial(): + rng = shared(np.random.RandomState(123)) + n = np.array([10, 40]) + p = np.array([0.3, 0.7]) + g = at.random.binomial(n, p, size=(10_000, 2), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1) + np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1) + + +@pytest.mark.skipif( + not numpyro_available, reason="BetaBinomial dispatch requires numpyro" +) +def test_beta_binomial(): + rng = shared(np.random.RandomState(123)) + n = np.array([10, 40]) + a = np.array([1.5, 13]) + b = np.array([0.5, 9]) + g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1) + np.testing.assert_allclose( + samples.std(axis=0), + np.sqrt((n * a * b * (a + b + n)) / ((a + b) ** 2 * (a + b + 1))), + rtol=0.1, + ) + + +@pytest.mark.skipif( + not numpyro_available, reason="Multinomial dispatch requires numpyro" +) +def test_multinomial(): + rng = shared(np.random.RandomState(123)) + n = np.array([10, 40]) + p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) + g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) + np.testing.assert_allclose( + samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1 + ) + + +@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") +def test_vonmises_mu_outside_circle(): + # Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle + # We test that the random draws from the PyTorch dispatch work as expected in these cases + rng = shared(np.random.RandomState(123)) + mu = np.array([-30, 40]) + kappa = np.array([100, 10]) + g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) + g_fn = random_function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose( + samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1 + ) + + # Circvar only does the correct thing in more recent versions of Scipy + # https://github.com/scipy/scipy/pull/5747 + # np.testing.assert_allclose( + # stats.circvar(samples, 
axis=0), + # 1 - special.iv(1, kappa) / special.iv(0, kappa), + # rtol=0.1, + # ) + + # For now simple compare with std from numpy draws + rng = np.random.default_rng(123) + ref_samples = rng.vonmises(mu, kappa, size=(10_000, 2)) + np.testing.assert_allclose( + np.std(samples, axis=0), np.std(ref_samples, axis=0), rtol=0.1 + ) + + +def test_random_unimplemented(): + """Compiling a graph with a non-supported `RandomVariable` should + raise an error. + + """ + + class NonExistentRV(RandomVariable): + name = "non-existent" + ndim_supp = 0 + ndims_params = [] + dtype = "floatX" + + def __call__(self, size=None, **kwargs): + return super().__call__(size=size, **kwargs) + + def rng_fn(cls, rng, size): + return 0 + + nonexistentrv = NonExistentRV() + rng = shared(np.random.RandomState(123)) + out = nonexistentrv(rng=rng) + fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) + + with pytest.raises(NotImplementedError): + with pytest.warns( + UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" + ): + compare_pytorch_and_py(fgraph, []) + + +def test_random_custom_implementation(): + """We can register a PyTorch implementation for user-defined `RandomVariable`s""" + + class CustomRV(RandomVariable): + name = "non-existent" + ndim_supp = 0 + ndims_params = [] + dtype = "floatX" + + def __call__(self, size=None, **kwargs): + return super().__call__(size=size, **kwargs) + + def rng_fn(cls, rng, size): + return 0 + + from pytensor.link.pytorch.dispatch.random import pytorch_sample_fn + + @pytorch_sample_fn.register(CustomRV) + def pytorch_sample_fn_custom(op): + def sample_fn(rng, size, dtype, *parameters): + return (rng, 0) + + return sample_fn + + nonexistentrv = CustomRV() + rng = shared(np.random.RandomState(123)) + out = nonexistentrv(rng=rng) + fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) + with pytest.warns( + UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" + ): + compare_pytorch_and_py(fgraph, []) + + +def test_random_concrete_shape(): + """PyTorch should compile when a `RandomVariable` is passed a concrete shape. + + There are three quantities that PyTorch considers as concrete: + 1. Constants known at compile time; + 2. The shape of an array. + 3. `static_argnums` parameters + This test makes sure that graphs with `RandomVariable`s compile when the + `size` parameter satisfies either of these criteria. + + """ + rng = shared(np.random.RandomState(123)) + x_at = at.dmatrix() + out = at.random.normal(0, 1, size=x_at.shape, rng=rng) + pytorch_fn = random_function([x_at], out, mode=pytorch_mode) + assert pytorch_fn(np.ones((2, 3))).shape == (2, 3) + + +def test_random_concrete_shape_from_param(): + rng = shared(np.random.RandomState(123)) + x_at = at.dmatrix() + out = at.random.normal(x_at, 1, rng=rng) + pytorch_fn = random_function([x_at], out, mode=pytorch_mode) + assert pytorch_fn(np.ones((2, 3))).shape == (2, 3) + + +def test_random_concrete_shape_subtensor(): + """PyTorch should compile when a concrete value is passed for the `size` parameter. + + This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar + inputs into 1d vectors is replaced by an `Op` that turns concrete scalar + inputs into tuples of concrete values using the `pytorch_size_parameter_as_tuple` + rewrite. + + PyTorch does not accept scalars as `size` or `shape` arguments, so this is a + slight improvement over their API. 
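A default-backend sketch of the pattern under test (illustration only, using the standard PyTensor API; the test below runs the same graph through the PyTorch mode):

import numpy as np
import pytensor.tensor as at
from pytensor import function
from pytensor.compile.sharedvalue import shared

rng = shared(np.random.default_rng(0))
x = at.dmatrix("x")
out = at.random.normal(0, 1, size=x.shape[1], rng=rng)  # size is a Subtensor of x.shape
fn = function([x], out)
assert fn(np.ones((2, 3))).shape == (3,)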
+ + """ + rng = shared(np.random.RandomState(123)) + x_at = at.dmatrix() + out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng) + pytorch_fn = random_function([x_at], out, mode=pytorch_mode) + assert pytorch_fn(np.ones((2, 3))).shape == (3,) + + +def test_random_concrete_shape_subtensor_tuple(): + """PyTorch should compile when a tuple of concrete values is passed for the `size` parameter. + + This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple + inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete + scalar inputs into tuples of concrete values using the + `pytorch_size_parameter_as_tuple` rewrite. + + """ + rng = shared(np.random.RandomState(123)) + x_at = at.dmatrix() + out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng) + pytorch_fn = random_function([x_at], out, mode=pytorch_mode) + assert pytorch_fn(np.ones((2, 3))).shape == (2,) + + +@pytest.mark.xfail( + reason="`size_at` should be specified as a static argument", strict=True +) +def test_random_concrete_shape_graph_input(): + rng = shared(np.random.RandomState(123)) + size_at = at.scalar() + out = at.random.normal(0, 1, size=size_at, rng=rng) + pytorch_fn = random_function([size_at], out, mode=pytorch_mode) + assert pytorch_fn(10).shape == (10,) diff --git a/tests/link/pytorch/test_scalar.py b/tests/link/pytorch/test_scalar.py new file mode 100644 index 0000000000..4e74bb6681 --- /dev/null +++ b/tests/link/pytorch/test_scalar.py @@ -0,0 +1,280 @@ +import numpy as np +import pytest + +import pytensor.scalar.basic as aes +import pytensor.tensor as at +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.scalar.basic import Composite +from pytensor.tensor import as_tensor +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.math import all as at_all +from pytensor.tensor.math import ( + cosh, + erf, + erfc, + erfcinv, + erfcx, + erfinv, + iv, + log, + log1mexp, + psi, + sigmoid, + softplus, +) +from pytensor.tensor.type import matrix, scalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +pytorch = pytest.importorskip("torch") +from pytensor.link.pytorch.dispatch import pytorch_funcify + + +try: + pass + + TFP_INSTALLED = True +except ModuleNotFoundError: + TFP_INSTALLED = False + + +def test_second(): + a0 = scalar("a0") + b = scalar("b") + + out = aes.second(a0, b) + fgraph = FunctionGraph([a0, b], [out]) + compare_pytorch_and_py(fgraph, [10.0, 5.0]) + + a1 = vector("a1") + out = at.second(a1, b) + fgraph = FunctionGraph([a1, b], [out]) + compare_pytorch_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0]) + + a2 = matrix("a2", shape=(1, None), dtype="float64") + b2 = matrix("b2", shape=(None, 1), dtype="int32") + out = at.second(a2, b2) + fgraph = FunctionGraph([a2, b2], [out]) + compare_pytorch_and_py( + fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")] + ) + + +def test_second_constant_scalar(): + b = scalar("b", dtype="int") + out = at.second(0.0, b) + fgraph = FunctionGraph([b], [out]) + # Test dispatch directly as useless second is removed during compilation + fn = pytorch_funcify(fgraph) + [res] = fn(1) + assert res == 1 + assert res.dtype == out.dtype + + +def test_identity(): + a = scalar("a") + a.tag.test_value = 10 + + out = aes.identity(a) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize( + "x, 
y, x_val, y_val", + [ + (scalar("x"), scalar("y"), np.array(10), np.array(20)), + (scalar("x"), vector("y"), np.array(10), np.arange(10, 20)), + ( + matrix("x"), + vector("y"), + np.arange(10 * 20).reshape((20, 10)), + np.arange(10, 20), + ), + ], +) +def test_pytorch_Composite_singe_output(x, y, x_val, y_val): + x_s = aes.float64("x") + y_s = aes.float64("y") + + comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)])) + + out = comp_op(x, y) + + out_fg = FunctionGraph([x, y], [out]) + + test_input_vals = [ + x_val.astype(config.floatX), + y_val.astype(config.floatX), + ] + _ = compare_pytorch_and_py(out_fg, test_input_vals) + + +def test_pytorch_Composite_multi_output(): + x = vector("x") + + x_s = aes.float64("xs") + outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x) + + fgraph = FunctionGraph([x], outs) + compare_pytorch_and_py(fgraph, [np.arange(10, dtype=config.floatX)]) + + +def test_erf(): + x = scalar("x") + out = erf(x) + fg = FunctionGraph([x], [out]) + + compare_pytorch_and_py(fg, [1.0]) + + +def test_erfc(): + x = scalar("x") + out = erfc(x) + fg = FunctionGraph([x], [out]) + + compare_pytorch_and_py(fg, [1.0]) + + +def test_erfinv(): + x = scalar("x") + out = erfinv(x) + fg = FunctionGraph([x], [out]) + + compare_pytorch_and_py(fg, [0.95]) + + +@pytest.mark.parametrize( + "op, test_values", + [ + (erfcx, (0.7,)), + (erfcinv, (0.7,)), + (iv, (0.3, 0.7)), + ], +) +@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability") +def test_tfp_ops(op, test_values): + inputs = [as_tensor(test_value).type() for test_value in test_values] + output = op(*inputs) + + fg = FunctionGraph(inputs, [output]) + compare_pytorch_and_py(fg, test_values) + + +def test_psi(): + x = scalar("x") + out = psi(x) + fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(fg, [3.0]) + + +def test_log1mexp(): + x = vector("x") + out = log1mexp(x) + fg = FunctionGraph([x], [out]) + + compare_pytorch_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]]) + + +def test_nnet(): + x = vector("x") + x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + + out = sigmoid(x) + fgraph = FunctionGraph([x], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = softplus(x) + fgraph = FunctionGraph([x], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_pytorch_variadic_Scalar(): + mu = vector("mu", dtype=config.floatX) + mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) + tau = vector("tau", dtype=config.floatX) + tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + + res = -tau * mu + + fgraph = FunctionGraph([mu, tau], [res]) + + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + res = -tau * (tau - mu) ** 2 + + fgraph = FunctionGraph([mu, tau], [res]) + + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_add_scalars(): + x = at.matrix("x") + size = x.shape[0] + x.shape[0] + x.shape[1] + out = at.ones(size).astype(config.floatX) + + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)]) + + +def test_mul_scalars(): + x = at.matrix("x") + size = x.shape[0] * x.shape[0] * x.shape[1] + out = at.ones(size).astype(config.floatX) + + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)]) + + +def test_div_scalars(): + x = at.matrix("x") + size = x.shape[0] // x.shape[1] + out = 
at.ones(size).astype(config.floatX) + + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)]) + + +def test_mod_scalars(): + x = at.matrix("x") + size = x.shape[0] % x.shape[1] + out = at.ones(size).astype(config.floatX) + + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)]) + + +def test_pytorch_multioutput(): + x = vector("x") + x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + y = vector("y") + y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + + w = cosh(x**2 + y / 3.0) + v = cosh(x / 3.0 + y**2) + + fgraph = FunctionGraph([x, y], [w, v]) + + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_pytorch_logp(): + mu = vector("mu") + mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX) + tau = vector("tau") + tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX) + sigma = vector("sigma") + sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX) + value = vector("value") + value.tag.test_value = np.r_[0.1, -10].astype(config.floatX) + + logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0 + conditions = [sigma > 0] + alltrue = at_all([at_all(1 * val) for val in conditions]) + normal_logp = at.switch(alltrue, logp, -np.inf) + + fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp]) + + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/pytorch/test_scan.py b/tests/link/pytorch/test_scan.py new file mode 100644 index 0000000000..7b0ef0ac18 --- /dev/null +++ b/tests/link/pytorch/test_scan.py @@ -0,0 +1,429 @@ +import re + +import numpy as np +import pytest + +import pytensor.tensor as at +from pytensor import function, shared +from pytensor.compile import get_mode +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.scan import until +from pytensor.scan.basic import scan +from pytensor.scan.op import Scan +from pytensor.tensor import random +from pytensor.tensor.math import gammaln, log +from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +pytorch = pytest.importorskip("torch") + + +@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)]) +def test_scan_sit_sot(view): + x0 = at.scalar("x0", dtype="float64") + xs, _ = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=10, + ) + if view: + xs = xs[view] + fg = FunctionGraph([x0], [xs]) + test_input_vals = [np.e] + compare_pytorch_and_py(fg, test_input_vals) + + +@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) +def test_scan_mit_sot(view): + x0 = at.vector("x0", dtype="float64", shape=(3,)) + xs, _ = scan( + lambda xtm3, xtm1: xtm3 + xtm1 + 1, + outputs_info=[{"initial": x0, "taps": [-3, -1]}], + n_steps=10, + ) + if view: + xs = xs[view] + fg = FunctionGraph([x0], [xs]) + test_input_vals = [np.full((3,), np.e)] + compare_pytorch_and_py(fg, test_input_vals) + + +@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)]) +@pytest.mark.parametrize("view_y", [None, (-1,), slice(-4, -1, None)]) +def test_scan_multiple_mit_sot(view_x, view_y): + x0 = at.vector("x0", dtype="float64", shape=(3,)) + y0 = at.vector("y0", dtype="float64", shape=(4,)) + + def step(xtm3, xtm1, ytm4, ytm2): + return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2 + + [xs, ys], _ = scan( + fn=step, + outputs_info=[ + 
{"initial": x0, "taps": [-3, -1]}, + {"initial": y0, "taps": [-4, -2]}, + ], + n_steps=10, + ) + if view_x: + xs = xs[view_x] + if view_y: + ys = ys[view_y] + + fg = FunctionGraph([x0, y0], [xs, ys]) + test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)] + compare_pytorch_and_py(fg, test_input_vals) + + +@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)]) +def test_scan_nit_sot(view): + rng = np.random.default_rng(seed=49) + + xs = at.vector("x0", dtype="float64", shape=(10,)) + + ys, _ = scan( + lambda x: at.exp(x), + outputs_info=[None], + sequences=[xs], + ) + if view: + ys = ys[view] + fg = FunctionGraph([xs], [ys]) + test_input_vals = [rng.normal(size=10)] + # We need to remove pushout rewrites, or the whole scan would just be + # converted to an Elemwise on xs + pytorch_fn, _ = compare_pytorch_and_py( + fg, test_input_vals, pytorch_mode=get_mode("PyTorch").excluding("scan_pushout") + ) + scan_nodes = [ + node for node in pytorch_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + assert len(scan_nodes) == 1 + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_scan_mit_mot(): + xs = at.vector("xs", shape=(10,)) + ys, _ = scan( + lambda xtm2, xtm1: (xtm2 + xtm1), + outputs_info=[{"initial": xs, "taps": [-2, -1]}], + n_steps=10, + ) + grads_wrt_xs = at.grad(ys.sum(), wrt=xs) + fg = FunctionGraph([xs], [grads_wrt_xs]) + compare_pytorch_and_py(fg, [np.arange(10)]) + + +def test_scan_update(): + sh_static = shared(np.array(0.0), name="sh_static") + sh_update = shared(np.array(1.0), name="sh_update") + + xs, update = scan( + lambda sh_static, sh_update: ( + sh_static + sh_update, + {sh_update: sh_update * 2}, + ), + outputs_info=[None], + non_sequences=[sh_static, sh_update], + strict=True, + n_steps=7, + ) + + pytorch_fn = function([], xs, updates=update, mode="PyTorch") + np.testing.assert_array_equal(pytorch_fn(), np.array([1, 2, 4, 8, 16, 32, 64]) + 0.0) + + sh_static.set_value(1.0) + np.testing.assert_array_equal( + pytorch_fn(), np.array([128, 256, 512, 1024, 2048, 4096, 8192]) + 1.0 + ) + + sh_static.set_value(2.0) + sh_update.set_value(1.0) + np.testing.assert_array_equal(pytorch_fn(), np.array([1, 2, 4, 8, 16, 32, 64]) + 2.0) + + +def test_scan_rng_update(): + rng = shared(np.random.default_rng(190), name="rng") + + def update_fn(rng): + new_rng, x = random.normal(rng=rng).owner.outputs + return x, {rng: new_rng} + + xs, update = scan( + update_fn, + outputs_info=[None], + non_sequences=[rng], + strict=True, + n_steps=10, + ) + + # Without updates + with pytest.warns( + UserWarning, + match=re.escape("[rng] will not be used in the compiled PyTorch graph"), + ): + pytorch_fn = function([], [xs], updates=None, mode="PyTorch") + + res1, res2 = pytorch_fn(), pytorch_fn() + assert np.unique(res1).size == 10 + assert np.unique(res2).size == 10 + np.testing.assert_array_equal(res1, res2) + + # With updates + with pytest.warns( + UserWarning, + match=re.escape("[rng] will not be used in the compiled PyTorch graph"), + ): + pytorch_fn = function([], [xs], updates=update, mode="PyTorch") + + res1, res2 = pytorch_fn(), pytorch_fn() + assert np.unique(res1).size == 10 + assert np.unique(res2).size == 10 + assert np.all(np.not_equal(res1, res2)) + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_scan_while(): + xs, _ = scan( + lambda x: (x + 1, until(x < 10)), + outputs_info=[at.zeros(())], + n_steps=100, + ) + + fg = FunctionGraph([], [xs]) + compare_pytorch_and_py(fg, []) + + +def test_scan_SEIR(): + """Test a scan implementation of a 
SEIR model. + + SEIR model definition: + S[t+1] = S[t] - B[t] + E[t+1] = E[t] +B[t] - C[t] + I[t+1] = I[t+1] + C[t] - D[t] + + B[t] ~ Binom(S[t], beta) + C[t] ~ Binom(E[t], gamma) + D[t] ~ Binom(I[t], delta) + """ + + def binomln(n, k): + return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) + + def binom_log_prob(n, p, value): + return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) + + # sequences + at_C = vector("C_t", dtype="int32", shape=(8,)) + at_D = vector("D_t", dtype="int32", shape=(8,)) + # outputs_info (initial conditions) + st0 = lscalar("s_t0") + et0 = lscalar("e_t0") + it0 = lscalar("i_t0") + logp_c = scalar("logp_c") + logp_d = scalar("logp_d") + # non_sequences + beta = scalar("beta") + gamma = scalar("gamma") + delta = scalar("delta") + + # TODO: Use random streams when their PyTorch conversions are implemented. + # trng = pytensor.tensor.random.RandomStream(1234) + + def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): + # bt0 = trng.binomial(n=st0, p=beta) + bt0 = st0 * beta + bt0 = bt0.astype(st0.dtype) + + logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) + logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) + + st1 = st0 - bt0 + et1 = et0 + bt0 - ct0 + it1 = it0 + ct0 - dt0 + return st1, et1, it1, logp_c1, logp_d1 + + (st, et, it, logp_c_all, logp_d_all), _ = scan( + fn=seir_one_step, + sequences=[at_C, at_D], + outputs_info=[st0, et0, it0, logp_c, logp_d], + non_sequences=[beta, gamma, delta], + ) + st.name = "S_t" + et.name = "E_t" + it.name = "I_t" + logp_c_all.name = "C_t_logp" + logp_d_all.name = "D_t_logp" + + out_fg = FunctionGraph( + [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], + [st, et, it, logp_c_all, logp_d_all], + ) + + s0, e0, i0 = 100, 50, 25 + logp_c0 = np.array(0.0, dtype=config.floatX) + logp_d0 = np.array(0.0, dtype=config.floatX) + beta_val, gamma_val, delta_val = ( + np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] + ) + C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) + D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) + + test_input_vals = [ + C, + D, + s0, + e0, + i0, + logp_c0, + logp_d0, + beta_val, + gamma_val, + delta_val, + ] + compare_pytorch_and_py(out_fg, test_input_vals) + + +def test_scan_mitsot_with_nonseq(): + a_at = scalar("a") + + def input_step_fn(y_tm1, y_tm3, a): + y_tm1.name = "y_tm1" + y_tm3.name = "y_tm3" + res = (y_tm1 + y_tm3) * a + res.name = "y_t" + return res + + y_scan_at, _ = scan( + fn=input_step_fn, + outputs_info=[ + { + "initial": at.as_tensor_variable( + np.r_[-1.0, 1.3, 0.0].astype(config.floatX) + ), + "taps": [-1, -3], + }, + ], + non_sequences=[a_at], + n_steps=10, + name="y_scan", + ) + y_scan_at.name = "y" + y_scan_at.owner.inputs[0].name = "y_all" + + out_fg = FunctionGraph([a_at], [y_scan_at]) + + test_input_vals = [np.array(10.0).astype(config.floatX)] + compare_pytorch_and_py(out_fg, test_input_vals) + + +@pytest.mark.parametrize("x0_func", [dvector, dmatrix]) +@pytest.mark.parametrize("A_func", [dmatrix, dmatrix]) +def test_nd_scan_sit_sot(x0_func, A_func): + x0 = x0_func("x0") + A = A_func("A") + + n_steps = 3 + k = 3 + + # Must specify mode = PyTorch for the inner func to avoid a GEMM Op in the PyTorch graph + xs, _ = scan( + lambda X, A: A @ X, + non_sequences=[A], + outputs_info=[x0], + n_steps=n_steps, + mode=get_mode("PyTorch"), + ) + + x0_val = ( + np.arange(k, dtype=config.floatX) + if x0.ndim == 1 + else np.diag(np.arange(k, dtype=config.floatX)) + ) + A_val = 
np.eye(k, dtype=config.floatX) + + fg = FunctionGraph([x0, A], [xs]) + test_input_vals = [x0_val, A_val] + compare_pytorch_and_py(fg, test_input_vals) + + +def test_nd_scan_sit_sot_with_seq(): + n_steps = 3 + k = 3 + + x = at.matrix("x0", shape=(n_steps, k)) + A = at.matrix("A", shape=(k, k)) + + # Must specify mode = PyTorch for the inner func to avoid a GEMM Op in the PyTorch graph + xs, _ = scan( + lambda X, A: A @ X, + non_sequences=[A], + sequences=[x], + n_steps=n_steps, + mode=get_mode("PyTorch"), + ) + + x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) + A_val = np.eye(k, dtype=config.floatX) + + fg = FunctionGraph([x, A], [xs]) + test_input_vals = [x_val, A_val] + compare_pytorch_and_py(fg, test_input_vals) + + +def test_nd_scan_mit_sot(): + x0 = at.matrix("x0", shape=(3, 3)) + A = at.matrix("A", shape=(3, 3)) + B = at.matrix("B", shape=(3, 3)) + + # Must specify mode = PyTorch for the inner func to avoid a GEMM Op in the PyTorch graph + xs, _ = scan( + lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1, + outputs_info=[{"initial": x0, "taps": [-3, -1]}], + non_sequences=[A, B], + n_steps=10, + mode=get_mode("PyTorch"), + ) + + fg = FunctionGraph([x0, A, B], [xs]) + x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) + A_val = np.eye(3, dtype=config.floatX) + B_val = np.eye(3, dtype=config.floatX) + + test_input_vals = [x0_val, A_val, B_val] + compare_pytorch_and_py(fg, test_input_vals) + + +def test_nd_scan_sit_sot_with_carry(): + x0 = at.vector("x0", shape=(3,)) + A = at.matrix("A", shape=(3, 3)) + + def step(x, A): + return A @ x, x.sum() + + # Must specify mode = PyTorch for the inner func to avoid a GEMM Op in the PyTorch graph + xs, _ = scan( + step, + outputs_info=[x0, None], + non_sequences=[A], + n_steps=10, + mode=get_mode("PyTorch"), + ) + + fg = FunctionGraph([x0, A], xs) + x0_val = np.arange(3, dtype=config.floatX) + A_val = np.eye(3, dtype=config.floatX) + + test_input_vals = [x0_val, A_val] + compare_pytorch_and_py(fg, test_input_vals) + + +def test_default_mode_excludes_incompatible_rewrites(): + # See issue #426 + A = matrix("A") + B = matrix("B") + out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) + fg = FunctionGraph([A, B], [out]) + compare_pytorch_and_py(fg, [np.eye(3), np.eye(3)]) diff --git a/tests/link/pytorch/test_shape.py b/tests/link/pytorch/test_shape.py new file mode 100644 index 0000000000..3c362943ac --- /dev/null +++ b/tests/link/pytorch/test_shape.py @@ -0,0 +1,88 @@ +import numpy as np +import pytest + +import pytensor.tensor as at +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape +from pytensor.tensor.type import iscalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_shape_ops(): + x_np = np.zeros((20, 3)) + x = Shape()(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, [], must_be_device_array=False) + + x = Shape_i(1)(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, [], must_be_device_array=False) + + +def test_pytorch_specify_shape(): + in_at = at.matrix("in") + x = at.specify_shape(in_at, (4, None)) + x_fg = FunctionGraph([in_at], [x]) + compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)]) + + # When used to assert two arrays have similar shapes + in_at = at.matrix("in") + 
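As a reminder of the semantics being ported, `specify_shape` only attaches (and at runtime asserts) a static shape and is otherwise an identity; a default-backend illustration, not part of the patch:

import numpy as np
import pytensor.tensor as at
from pytensor import function

v = at.matrix("v")
w = at.specify_shape(v, (4, None))   # assert the first dimension is exactly 4
f = function([v], w)

f(np.ones((4, 5)))                   # fine: the runtime shape matches
# f(np.ones((3, 5))) would raise, because the asserted shape is violated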
shape_at = at.matrix("shape") + x = at.specify_shape(in_at, shape_at.shape) + x_fg = FunctionGraph([in_at, shape_at], [x]) + compare_pytorch_and_py( + x_fg, + [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], + ) + + +def test_pytorch_Reshape_constant(): + a = vector("a") + x = reshape(a, (2, 2)) + x_fg = FunctionGraph([a], [x]) + compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +def test_pytorch_Reshape_concrete_shape(): + """PyTorch should compile when a concrete value is passed for the `shape` parameter.""" + a = vector("a") + x = reshape(a, a.shape) + x_fg = FunctionGraph([a], [x]) + compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) + x_fg = FunctionGraph([a], [x]) + compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +@pytest.mark.xfail( + reason="`shape_at` should be specified as a static argument", strict=True +) +def test_pytorch_Reshape_shape_graph_input(): + a = vector("a") + shape_at = iscalar("b") + x = reshape(a, (shape_at, shape_at)) + x_fg = FunctionGraph([a, shape_at], [x]) + compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) + + +def test_pytorch_compile_ops(): + x = DeepCopyOp()(at.as_tensor_variable(1.1)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + x_np = np.zeros((20, 1, 1)) + x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + x = ViewOp()(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py new file mode 100644 index 0000000000..c975750ab3 --- /dev/null +++ b/tests/link/pytorch/test_slinalg.py @@ -0,0 +1,131 @@ +import numpy as np +import pytest + +import pytensor.tensor as at +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor import nlinalg as at_nlinalg +from pytensor.tensor import slinalg as at_slinalg +from pytensor.tensor import subtensor as at_subtensor +from pytensor.tensor.math import clip, cosh +from pytensor.tensor.type import matrix, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_basic(): + rng = np.random.default_rng(28494) + + x = matrix("x") + y = matrix("y") + b = vector("b") + + # `ScalarOp` + z = cosh(x**2 + y / 3.0) + + # `[Inc]Subtensor` + out = at_subtensor.set_subtensor(z[0], -10.0) + out = at_subtensor.inc_subtensor(out[0, 1], 2.0) + out = out[:5, :3] + + out_fg = FunctionGraph([x, y], [out]) + + test_input_vals = [ + np.tile(np.arange(10), (10, 1)).astype(config.floatX), + np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), + ] + _, [pytorch_res] = compare_pytorch_and_py(out_fg, test_input_vals) + + # Confirm that the `Subtensor` slice operations are correct + assert pytorch_res.shape == (5, 3) + + # Confirm that the `IncSubtensor` operations are correct + assert pytorch_res[0, 0] == -10.0 + assert pytorch_res[0, 1] == -8.0 + + out = clip(x, y, 5) + out_fg = FunctionGraph([x, y], [out]) + compare_pytorch_and_py(out_fg, test_input_vals) + + out = at.diagonal(x, 0) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py( + out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] + ) + + out = at_slinalg.cholesky(x) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py( + 
out_fg, + [ + (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( + config.floatX + ) + ], + ) + + # not sure why this isn't working yet with lower=False + out = at_slinalg.Cholesky(lower=False)(x) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py( + out_fg, + [ + (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( + config.floatX + ) + ], + ) + + out = at_slinalg.solve(x, b) + out_fg = FunctionGraph([x, b], [out]) + compare_pytorch_and_py( + out_fg, + [ + np.eye(10).astype(config.floatX), + np.arange(10).astype(config.floatX), + ], + ) + + out = at.diag(b) + out_fg = FunctionGraph([b], [out]) + compare_pytorch_and_py(out_fg, [np.arange(10).astype(config.floatX)]) + + out = at_nlinalg.det(x) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py( + out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] + ) + + out = at_nlinalg.matrix_inverse(x) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py( + out_fg, + [ + (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( + config.floatX + ) + ], + ) + + +@pytest.mark.parametrize("check_finite", [False, True]) +@pytest.mark.parametrize("lower", [False, True]) +@pytest.mark.parametrize("trans", [0, 1, 2]) +def test_pytorch_SolveTriangular(trans, lower, check_finite): + x = matrix("x") + b = vector("b") + + out = at_slinalg.solve_triangular( + x, + b, + trans=trans, + lower=lower, + check_finite=check_finite, + ) + out_fg = FunctionGraph([x, b], [out]) + compare_pytorch_and_py( + out_fg, + [ + np.eye(10).astype(config.floatX), + np.arange(10).astype(config.floatX), + ], + ) diff --git a/tests/link/pytorch/test_sparse.py b/tests/link/pytorch/test_sparse.py new file mode 100644 index 0000000000..bcc30aa66a --- /dev/null +++ b/tests/link/pytorch/test_sparse.py @@ -0,0 +1,75 @@ +import numpy as np +import pytest +import scipy.sparse + +import pytensor.sparse as ps +import pytensor.tensor as pt +from pytensor import function +from pytensor.graph import FunctionGraph +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +@pytest.mark.parametrize( + "op, x_type, y_type", + [ + (ps.dot, pt.vector, ps.matrix), + (ps.dot, pt.matrix, ps.matrix), + (ps.dot, ps.matrix, pt.vector), + (ps.dot, ps.matrix, pt.matrix), + # structured_dot only allows matrix @ matrix + (ps.structured_dot, pt.matrix, ps.matrix), + (ps.structured_dot, ps.matrix, pt.matrix), + ], +) +def test_sparse_dot_constant_sparse(x_type, y_type, op): + inputs = [] + test_values = [] + + if x_type is ps.matrix: + x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32") + x_pt = ps.as_sparse_variable(x_sp, name="x") + else: + x_pt = x_type("x", dtype="float32") + if x_pt.ndim == 1: + x_test = np.arange(40, dtype="float32") + else: + x_test = np.arange(5 * 40, dtype="float32").reshape(5, 40) + inputs.append(x_pt) + test_values.append(x_test) + + if y_type is ps.matrix: + y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32") + y_pt = ps.as_sparse_variable(y_sp, name="y") + else: + y_pt = y_type("y", dtype="float32") + if y_pt.ndim == 1: + y_test = np.arange(40, dtype="float32") + else: + y_test = np.arange(40 * 3, dtype="float32").reshape(40, 3) + inputs.append(y_pt) + test_values.append(y_test) + + dot_pt = op(x_pt, y_pt) + fgraph = FunctionGraph(inputs, [dot_pt]) + compare_pytorch_and_py(fgraph, test_values) + + +def test_sparse_dot_non_const_raises(): + x_pt = pt.vector("x") + + y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32") + 
y_pt = ps.as_sparse_variable(y_sp, name="y").type() + + out = ps.dot(x_pt, y_pt) + + msg = "PyTorch sparse dot only implemented for constant sparse inputs" + + with pytest.raises(NotImplementedError, match=msg): + function([x_pt, y_pt], out, mode="PyTorch") + + y_pt_shared = ps.shared(y_sp, name="y") + + out = ps.dot(x_pt, y_pt_shared) + + with pytest.raises(NotImplementedError, match=msg): + function([x_pt], out, mode="PyTorch") diff --git a/tests/link/pytorch/test_subtensor.py b/tests/link/pytorch/test_subtensor.py new file mode 100644 index 0000000000..dbeb23f958 --- /dev/null +++ b/tests/link/pytorch/test_subtensor.py @@ -0,0 +1,251 @@ +import numpy as np +import pytest + +import pytensor.tensor as at +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor import subtensor as at_subtensor +from pytensor.tensor.rewriting.pytorch import ( + boolean_indexing_set_or_inc, + boolean_indexing_sum, +) +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_Subtensor_constant(): + # Basic indices + x_at = at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + out_at = x_at[1, 2, 0] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at = x_at[1:, 1, :] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at = x_at[:2, 1, :] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at = x_at[1:2, 1, :] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # Advanced indexing + out_at = at_subtensor.advanced_subtensor1(x_at, [1, 2]) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at = x_at[[1, 2], [2, 3]] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # Advanced and basic indexing + out_at = x_at[[1, 2], :] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at = x_at[[1, 2], :, [3, 4]] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # Flipping + out_at = x_at[::-1] + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + +@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling") +def test_pytorch_Subtensor_dynamic(): + a = at.iscalar("a") + x = at.arange(3) + out_at = x[:a] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([a], [out_at]) + compare_pytorch_and_py(out_fg, [1]) + + +def test_pytorch_Subtensor_boolean_mask(): + """PyTorch does not support resizing arrays with boolean masks.""" + x_at = at.vector("x", dtype="float64") + out_at = x_at[x_at < 0] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + + out_fg = FunctionGraph([x_at], [out_at]) + + x_at_test = np.arange(-5, 5) + with pytest.raises(NotImplementedError, match="resizing arrays with boolean"): + compare_pytorch_and_py(out_fg, [x_at_test]) + + +def 
test_pytorch_Subtensor_boolean_mask_reexpressible(): + """Summing values with boolean indexing. + + This test ensures that the sum of an `AdvancedSubtensor` `Op`s with boolean + indexing is replaced with the sum of an equivalent `Switch` `Op`, using the + `pytorch_boolean_indexing_sum` rewrite. + + PyTorch forces users to re-express this logic manually, so this is an + improvement over its user interface. + + """ + x_at = at.matrix("x") + out_at = x_at[x_at < 0].sum() + out_fg = FunctionGraph([x_at], [out_at]) + compare_pytorch_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)]) + + +def test_boolean_indexing_sum_not_applicable(): + """Test that boolean_indexing_sum does not return an invalid replacement in cases where it doesn't apply.""" + x = at.matrix("x") + out = x[x[:, 0] < 0, :].sum(axis=-1) + fg = FunctionGraph([x], [out]) + assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None + + out = x[x[:, 0] < 0, 0].sum() + fg = FunctionGraph([x], [out]) + assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None + + +def test_pytorch_IncSubtensor(): + rng = np.random.default_rng(213234) + + x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) + x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) + + # "Set" basic indices + st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) + out_at = at_subtensor.set_subtensor(x_at[1, 2, 3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.set_subtensor(x_at[:2, 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # "Set" advanced indices + st_at = at.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) + ) + out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # "Set" boolean indices + mask_at = at.constant(x_np > 0) + out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # "Increment" basic indices + st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) + out_at = at_subtensor.inc_subtensor(x_at[1, 2, 3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.inc_subtensor(x_at[:2, 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + out_at 
= at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # "Increment" advanced indices + st_at = at.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) + ) + out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + # "Increment" boolean indices + mask_at = at.constant(x_np > 0) + out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) + out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) + out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_pytorch_and_py(out_fg, []) + + +def test_pytorch_IncSubtensor_boolean_indexing_reexpressible(): + """Setting or incrementing values with boolean indexing. + + This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing is + replaced with an equivalent `Switch` `Op`, using the + `boolean_indexing_set_of_inc` rewrite. + + PyTorch forces users to re-express this logic manually, so this is an + improvement over its user interface. 
+ + """ + rng = np.random.default_rng(213234) + x_np = rng.uniform(-1, 1, size=(4, 5)).astype(config.floatX) + + x_at = at.matrix("x") + mask_at = at.as_tensor(x_at) > 0 + out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_at], [out_at]) + compare_pytorch_and_py(out_fg, [x_np]) + + mask_at = at.as_tensor(x_at) > 0 + out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_at], [out_at]) + compare_pytorch_and_py(out_fg, [x_np]) + + +def test_boolean_indexing_set_or_inc_not_applicable(): + """Test that `boolean_indexing_set_or_inc` does not return an invalid replacement in cases where it doesn't apply.""" + x = at.vector("x") + mask = at.as_tensor(x) > 0 + out = at_subtensor.set_subtensor(x[mask], [0, 1, 2]) + fg = FunctionGraph([x], [out]) + assert boolean_indexing_set_or_inc.transform(fg, fg.outputs[0].owner) is None diff --git a/tests/link/pytorch/test_tensor_basic.py b/tests/link/pytorch/test_tensor_basic.py new file mode 100644 index 0000000000..32428f0258 --- /dev/null +++ b/tests/link/pytorch/test_tensor_basic.py @@ -0,0 +1,239 @@ +import numpy as np +import pytest + +from pytensor.compile import get_mode + + +pytorch = pytest.importorskip("torch") +import pytorch.errors + +import pytensor +import pytensor.tensor.basic as at +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.tensor.type import iscalar, matrix, scalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py +from tests.tensor.test_basic import TestAlloc + + +def test_pytorch_Alloc(): + x = at.alloc(0.0, 2, 3) + x_fg = FunctionGraph([], [x]) + + _, [pytorch_res] = compare_pytorch_and_py(x_fg, []) + + assert pytorch_res.shape == (2, 3) + + x = at.alloc(1.1, 2, 3) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + x = at.AllocEmpty("float32")(2, 3) + x_fg = FunctionGraph([], [x]) + + def compare_shape_dtype(x, y): + (x,) = x + (y,) = y + return x.shape == y.shape and x.dtype == y.dtype + + compare_pytorch_and_py(x_fg, [], assert_fn=compare_shape_dtype) + + a = scalar("a") + x = at.alloc(a, 20) + x_fg = FunctionGraph([a], [x]) + + compare_pytorch_and_py(x_fg, [10.0]) + + a = vector("a") + x = at.alloc(a, 20, 10) + x_fg = FunctionGraph([a], [x]) + + compare_pytorch_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) + + +def test_alloc_runtime_broadcast(): + TestAlloc.check_runtime_broadcast(get_mode("PyTorch")) + + +def test_pytorch_MakeVector(): + x = at.make_vector(1, 2, 3) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + +def test_arange(): + out = at.arange(1, 10, 2) + fgraph = FunctionGraph([], [out]) + compare_pytorch_and_py(fgraph, []) + + +def test_arange_of_shape(): + x = vector("x") + out = at.arange(1, x.shape[-1], 2) + fgraph = FunctionGraph([x], [out]) + compare_pytorch_and_py(fgraph, [np.zeros((5,))]) + + +def test_arange_nonconcrete(): + """PyTorch cannot JIT-compile `pytorch.numpy.arange` when arguments are not concrete values.""" + + a = scalar("a") + a.tag.test_value = 10 + out = at.arange(a) + + with pytest.raises(NotImplementedError): + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_pytorch_Join(): + a = matrix("a") + b = matrix("b") + + x = at.join(0, a, b) + x_fg = 
FunctionGraph([a, b], [x]) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), + ], + ) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0]].astype(config.floatX), + ], + ) + + x = at.join(1, a, b) + x_fg = FunctionGraph([a, b], [x]) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), + ], + ) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), + np.c_[[5.0, 6.0]].astype(config.floatX), + ], + ) + + +class TestJaxSplit: + def test_basic(self): + a = matrix("a") + a_splits = at.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0) + fg = FunctionGraph([a], a_splits) + compare_pytorch_and_py( + fg, + [ + np.zeros((6, 4)).astype(config.floatX), + ], + ) + + a = matrix("a", shape=(6, None)) + a_splits = at.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0) + fg = FunctionGraph([a], a_splits) + compare_pytorch_and_py( + fg, + [ + np.zeros((6, 4)).astype(config.floatX), + ], + ) + + def test_runtime_errors(self): + a = matrix("a") + + a_splits = at.split(a, splits_size=[2, 2, 2], n_splits=2, axis=0) + fn = pytensor.function([a], a_splits, mode="PyTorch") + with pytest.raises( + ValueError, match="Length of splits is not equal to n_splits" + ): + fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) + + a_splits = at.split(a, splits_size=[2, 4], n_splits=3, axis=0) + fn = pytensor.function([a], a_splits, mode="PyTorch") + with pytest.raises( + ValueError, match="Length of splits is not equal to n_splits" + ): + fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) + + a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=0) + fn = pytensor.function([a], a_splits, mode="PyTorch") + with pytest.raises( + ValueError, match="Split sizes do not sum up to input length along axis: 7" + ): + fn(np.zeros((7, 4), dtype=pytensor.config.floatX)) + + a_splits = at.split(a, splits_size=[2, -4, 8], n_splits=3, axis=0) + fn = pytensor.function([a], a_splits, mode="PyTorch") + with pytest.raises( + ValueError, + match="Split sizes cannot be negative", + ): + fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) + + def test_pytorch_split_not_supported(self): + a = matrix("a", shape=(6, None)) + + a_splits = at.split(a, splits_size=[2, a.shape[1] - 2], n_splits=2, axis=1) + with pytest.warns( + UserWarning, match="Split node does not have constant split positions." 
+ ): + fn = pytensor.function([a], a_splits, mode="PyTorch") + # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it + with pytest.raises(AttributeError): + fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) + + split_axis = iscalar("split_axis") + a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis) + with pytest.warns(UserWarning, match="Split node does not have constant axis."): + fn = pytensor.function([a, split_axis], a_splits, mode="PyTorch") + # Same as above, an AttributeError surpasses the `TracerIntegerConversionError` + # Both errors are included for backwards compatibility + with pytest.raises((AttributeError, pytorch.errors.TracerIntegerConversionError)): + fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0) + + +def test_pytorch_eye(): + """Tests pytorchification of the Eye operator""" + out = at.eye(3) + out_fg = FunctionGraph([], [out]) + + compare_pytorch_and_py(out_fg, []) + + +def test_tri(): + out = at.tri(10, 10, 0) + fgraph = FunctionGraph([], [out]) + compare_pytorch_and_py(fgraph, []) + + +def test_tri_nonconcrete(): + """PyTorch cannot JIT-compile `pytorch.numpy.tri` when arguments are not concrete values.""" + + m, n, k = ( + scalar("a", dtype="int64"), + scalar("n", dtype="int64"), + scalar("k", dtype="int64"), + ) + m.tag.test_value = 10 + n.tag.test_value = 10 + k.tag.test_value = 0 + + out = at.tri(m, n, k) + + # The actual error the user will see should be pytorch.errors.ConcretizationTypeError, but + # the error handler raises an Attribute error first, so that's what this test needs to pass + with pytest.raises(AttributeError): + fgraph = FunctionGraph([m, n, k], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) From 7aee70dd51f61c298be94f22a76d32db6ca9813a Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Mon, 25 Sep 2023 17:56:03 +0200 Subject: [PATCH 3/4] Add linker. 
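
The new "PYTORCH" mode wires the PyTorch linker into the existing mode
machinery so a graph can be compiled through the backend by name. A minimal
usage sketch, assuming only the "PYTORCH" mode name registered in
pytensor/compile/mode.py below (illustrative, not part of the diff):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # Request the PyTorch linker via the mode added in this patch
    fn = pytensor.function([x], pt.exp(x).sum(), mode="PYTORCH")
    fn([0.0, 1.0, 2.0])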
--- pytensor/compile/mode.py | 21 ++++++- pytensor/configdefaults.py | 1 + pytensor/link/pytorch/linker.py | 74 +++++++++++++++++++++++++ tests/link/pytorch/test_tensor_basic.py | 4 +- 4 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 pytensor/link/pytorch/linker.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 6dd5496505..c3fba340cf 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -28,6 +28,7 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.pytorch.linker import PyTorchLinker from pytensor.link.vm import VMLinker @@ -48,6 +49,7 @@ "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True), "jax": JAXLinker(), "numba": NumbaLinker(), + "pytorch": PyTorchLinker(), } @@ -469,13 +471,26 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], ), ) - +PYTORCH = Mode( + PyTorchLinker(), + RewriteDatabaseQuery( + include=["fast_run"], + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "local_uint_constant_indices", + ], + ), +) predefined_modes = { "FAST_COMPILE": FAST_COMPILE, "FAST_RUN": FAST_RUN, "JAX": JAX, "NUMBA": NUMBA, + "PYTORCH": PYTORCH, } instantiated_default_mode = None @@ -548,7 +563,7 @@ def register_mode(name, mode): predefined_modes[name] = mode -def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], ...]: +def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax", "pytorch"], ...]: """Get the compilation target language.""" if mode is None: @@ -560,6 +575,8 @@ def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], return ("numba",) if isinstance(linker, JAXLinker): return ("jax",) + if isinstance(linker, PyTorchLinker): + return ("pytorch",) if isinstance(linker, PerformLinker): return ("py",) if isinstance(linker, CLinker): diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index 58f2f2faa9..5a8f34e1ab 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -107,6 +107,7 @@ def _filter_mode(val): "DEBUG_MODE", "JAX", "NUMBA", + "PYTORCH" ] if val in str_options: return val diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py new file mode 100644 index 0000000000..be0b0057ac --- /dev/null +++ b/pytensor/link/pytorch/linker.py @@ -0,0 +1,74 @@ +import warnings +import torch + +from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.graph.basic import Constant +from pytensor.link.basic import JITLinker + + +class PyTorchLinker(JITLinker): + """A `Linker` that JIT-compiles NumPy-based operations using PyTorch.""" + + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): + from pytensor.link.pytorch.dispatch import pytorch_funcify + from pytensor.tensor.random.type import RandomType + + shared_rng_inputs = [ + inp + for inp in fgraph.inputs + if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)) + ] + + if shared_rng_inputs: + warnings.warn( + f"The RandomType SharedVariables {shared_rng_inputs} will not be used " + f"in the compiled PyTorch graph. 
Instead a copy will be used.", + UserWarning, + ) + new_shared_rng_inputs = [ + shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs + ] + + fgraph.replace_all( + zip(shared_rng_inputs, new_shared_rng_inputs), + import_missing=True, + reason="PyTorchLinker.fgraph_convert", + ) + + for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs): + new_inp_storage = [new_inp.get_value(borrow=True)] + storage_map[new_inp] = new_inp_storage + old_inp_storage = storage_map.pop(old_inp) + for input_storage_idx, input_storage_item in enumerate(input_storage): + if input_storage_item is old_inp_storage: + break + else: # no break + raise ValueError() + input_storage[input_storage_idx] = new_inp_storage + fgraph.remove_input( + fgraph.inputs.index(old_inp), reason="PyTorchLinker.fgraph_convert" + ) + + return pytorch_funcify( + fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs + ) + + def jit_compile(self, fn): + # For PyTorch, the script mode allows for JIT compilation + scripted_fn = torch.jit.script(fn) + return scripted_fn + + def create_thunk_inputs(self, storage_map): + from pytensor.link.pytorch.dispatch import pytorch_typify + + thunk_inputs = [] + for n in self.fgraph.inputs: + sinput = storage_map[n] + if isinstance(sinput[0], torch.Generator): + new_value = pytorch_typify( + sinput[0], dtype=getattr(sinput[0], "dtype", None) + ) + sinput[0] = new_value + thunk_inputs.append(sinput) + + return thunk_inputs diff --git a/tests/link/pytorch/test_tensor_basic.py b/tests/link/pytorch/test_tensor_basic.py index 32428f0258..c5c9eb9a4a 100644 --- a/tests/link/pytorch/test_tensor_basic.py +++ b/tests/link/pytorch/test_tensor_basic.py @@ -5,7 +5,7 @@ pytorch = pytest.importorskip("torch") -import pytorch.errors +import torch.errors import pytensor import pytensor.tensor.basic as at @@ -200,7 +200,7 @@ def test_pytorch_split_not_supported(self): fn = pytensor.function([a, split_axis], a_splits, mode="PyTorch") # Same as above, an AttributeError surpasses the `TracerIntegerConversionError` # Both errors are included for backwards compatibility - with pytest.raises((AttributeError, pytorch.errors.TracerIntegerConversionError)): + with pytest.raises((AttributeError, torch.errors.TracerIntegerConversionError)): fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0) From 264e08cbe4a63025a9a36f1772465a99b22230d8 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Mon, 25 Sep 2023 18:57:17 +0200 Subject: [PATCH 4/4] Fixes so that test-suite runs. 
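
The dispatch helpers are renamed from the `torch_*` to the `pytorch_*` prefix
so they match the names the rest of the backend imports, the Scan dispatch now
uses the `PYTORCH` mode object, and tests that depend on numpyro availability
checks, the not-yet-ported `pytensor.tensor.rewriting.pytorch` module, or
`torch.errors` are dropped. New Op conversions should register against the
renamed dispatcher; a minimal sketch with a hypothetical `MyOp` (not an Op
touched by this diff):

    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify

    @pytorch_funcify.register(MyOp)
    def pytorch_funcify_MyOp(op, **kwargs):
        def my_op_fn(x):
            # Return the torch computation that implements MyOp
            return x

        return my_op_fn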
--- pytensor/link/pytorch/dispatch/basic.py | 38 +++++++++--------- pytensor/link/pytorch/dispatch/scan.py | 4 +- pytensor/link/pytorch/dispatch/slinalg.py | 14 +++---- pytensor/link/pytorch/dispatch/subtensor.py | 44 ++++++++++----------- tests/link/pytorch/test_random.py | 14 ------- tests/link/pytorch/test_subtensor.py | 25 ------------ tests/link/pytorch/test_tensor_basic.py | 5 --- 7 files changed, 50 insertions(+), 94 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index da382fddb4..282e70535b 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -13,7 +13,7 @@ @singledispatch -def torch_typify(data, dtype=None, **kwargs): +def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" if dtype is None: return data @@ -21,21 +21,21 @@ def torch_typify(data, dtype=None, **kwargs): return torch.tensor(data, dtype=dtype) -@torch_typify.register(np.ndarray) -def torch_typify_ndarray(data, dtype=None, **kwargs): +@pytorch_typify.register(np.ndarray) +def pytorch_typify_ndarray(data, dtype=None, **kwargs): if len(data.shape) == 0: return data.item() return torch.tensor(data, dtype=dtype) @singledispatch -def torch_funcify(op, node=None, storage_map=None, **kwargs): +def pytorch_funcify(op, node=None, storage_map=None, **kwargs): """Create a PyTorch compatible function from an PyTensor `Op`.""" raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}") -@torch_funcify.register(FunctionGraph) -def torch_funcify_FunctionGraph( +@pytorch_funcify.register(FunctionGraph) +def pytorch_funcify_FunctionGraph( fgraph, node=None, fgraph_name="torch_funcified_fgraph", @@ -43,15 +43,15 @@ def torch_funcify_FunctionGraph( ): return fgraph_to_python( fgraph, - torch_funcify, - type_conversion_fn=torch_typify, + pytorch_funcify, + type_conversion_fn=pytorch_typify, fgraph_name=fgraph_name, **kwargs, ) -@torch_funcify.register(IfElse) -def torch_funcify_IfElse(op, **kwargs): +@pytorch_funcify.register(IfElse) +def pytorch_funcify_IfElse(op, **kwargs): n_outs = op.n_outs def ifelse(cond, *args, n_outs=n_outs): @@ -63,9 +63,9 @@ def ifelse(cond, *args, n_outs=n_outs): return ifelse -@torch_funcify.register(Assert) -@torch_funcify.register(CheckAndRaise) -def torch_funcify_CheckAndRaise(op, **kwargs): +@pytorch_funcify.register(Assert) +@pytorch_funcify.register(CheckAndRaise) +def pytorch_funcify_CheckAndRaise(op, **kwargs): warnings.warn( f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as PyTorch tracing would remove it.""", stacklevel=2, @@ -77,7 +77,7 @@ def assert_fn(x, *inputs): return assert_fn -def torch_safe_copy(x): +def pytorch_safe_copy(x): try: res = torch.clone(x) except NotImplementedError: @@ -93,16 +93,16 @@ def torch_safe_copy(x): return res -@torch_funcify.register(DeepCopyOp) -def torch_funcify_DeepCopyOp(op, **kwargs): +@pytorch_funcify.register(DeepCopyOp) +def pytorch_funcify_DeepCopyOp(op, **kwargs): def deepcopyop(x): - return torch_safe_copy(x) + return pytorch_safe_copy(x) return deepcopyop -@torch_funcify.register(ViewOp) -def torch_funcify_ViewOp(op, **kwargs): +@pytorch_funcify.register(ViewOp) +def pytorch_funcify_ViewOp(op, **kwargs): def viewop(x): return x diff --git a/pytensor/link/pytorch/dispatch/scan.py b/pytensor/link/pytorch/dispatch/scan.py index 9730a190ab..7d653840a0 100644 --- a/pytensor/link/pytorch/dispatch/scan.py +++ b/pytensor/link/pytorch/dispatch/scan.py @@ -1,6 +1,6 @@ import 
torch -from pytensor.compile.mode import PyTorch +from pytensor.compile.mode import PYTORCH from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scan.op import Scan @@ -18,7 +18,7 @@ def pytorch_funcify_Scan(op: Scan, **kwargs): ) # Optimize inner graph (exclude any defalut rewrites that are incompatible with PyTorch mode) - rewriter = op.mode_instance.excluding(*PyTorch._optimizer.exclude).optimizer + rewriter = op.mode_instance.excluding(*PYTORCH._optimizer.exclude).optimizer rewriter(op.fgraph) scan_inner_func = pytorch_funcify(op.fgraph, **kwargs) diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py index 2d783097b9..d75119b33d 100644 --- a/pytensor/link/pytorch/dispatch/slinalg.py +++ b/pytensor/link/pytorch/dispatch/slinalg.py @@ -1,11 +1,11 @@ import torch -from pytensor.link.pytorch.dispatch.basic import torch_funcify +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular -@torch_funcify.register(Cholesky) -def torch_funcify_Cholesky(op, **kwargs): +@pytorch_funcify.register(Cholesky) +def pytorch_funcify_Cholesky(op, **kwargs): lower = op.lower def cholesky(a, lower=lower): @@ -14,8 +14,8 @@ def cholesky(a, lower=lower): return cholesky -@torch_funcify.register(Solve) -def torch_funcify_Solve(op, **kwargs): +@pytorch_funcify.register(Solve) +def pytorch_funcify_Solve(op, **kwargs): if op.assume_a != "gen" and op.lower: lower = True else: @@ -30,8 +30,8 @@ def solve(a, b, lower=lower): return solve -@torch_funcify.register(SolveTriangular) -def torch_funcify_SolveTriangular(op, **kwargs): +@pytorch_funcify.register(SolveTriangular) +def pytorch_funcify_SolveTriangular(op, **kwargs): lower = op.lower trans = op.trans unit_diagonal = op.unit_diagonal diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py index 1ec8384283..1e2cd4db82 100644 --- a/pytensor/link/pytorch/dispatch/subtensor.py +++ b/pytensor/link/pytorch/dispatch/subtensor.py @@ -1,5 +1,5 @@ import torch -from pytensor.link.pytorch.dispatch.basic import torch_funcify +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -31,7 +31,7 @@ """ -def subtensor_assert_indices_torch_compatible(node, idx_list): +def subtensor_assert_indices_pytorch_compatible(node, idx_list): from pytensor.graph.basic import Constant from pytensor.tensor.variable import TensorVariable @@ -46,12 +46,12 @@ def subtensor_assert_indices_torch_compatible(node, idx_list): raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR) -@torch_funcify.register(Subtensor) -@torch_funcify.register(AdvancedSubtensor) -@torch_funcify.register(AdvancedSubtensor1) -def torch_funcify_Subtensor(op, node, **kwargs): +@pytorch_funcify.register(Subtensor) +@pytorch_funcify.register(AdvancedSubtensor) +@pytorch_funcify.register(AdvancedSubtensor1) +def pytorch_funcify_Subtensor(op, node, **kwargs): idx_list = getattr(op, "idx_list", None) - subtensor_assert_indices_torch_compatible(node, idx_list) + subtensor_assert_indices_pytorch_compatible(node, idx_list) def subtensor_constant(x, *ilists): indices = indices_from_subtensor(ilists, idx_list) @@ -63,55 +63,55 @@ def subtensor_constant(x, *ilists): return subtensor_constant -@torch_funcify.register(IncSubtensor) -@torch_funcify.register(AdvancedIncSubtensor1) -def torch_funcify_IncSubtensor(op, node, **kwargs): 
+@pytorch_funcify.register(IncSubtensor) +@pytorch_funcify.register(AdvancedIncSubtensor1) +def pytorch_funcify_IncSubtensor(op, node, **kwargs): idx_list = getattr(op, "idx_list", None) if getattr(op, "set_instead_of_inc", False): - def torch_fn(x, indices, y): + def pytorch_fn(x, indices, y): x[indices] = y return x else: - def torch_fn(x, indices, y): + def pytorch_fn(x, indices, y): x[indices] += y return x - def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list): + def incsubtensor(x, y, *ilist, pytorch_fn=pytorch_fn, idx_list=idx_list): indices = indices_from_subtensor(ilist, idx_list) if len(indices) == 1: indices = indices[0] - return torch_fn(x, indices, y) + return pytorch_fn(x, indices, y) return incsubtensor -@torch_funcify.register(AdvancedIncSubtensor) -def torch_funcify_AdvancedIncSubtensor(op, node, **kwargs): +@pytorch_funcify.register(AdvancedIncSubtensor) +def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): if getattr(op, "set_instead_of_inc", False): - def torch_fn(x, indices, y): + def pytorch_fn(x, indices, y): x[indices] = y return x else: - def torch_fn(x, indices, y): + def pytorch_fn(x, indices, y): x[indices] += y return x - def advancedincsubtensor(x, y, *ilist, torch_fn=torch_fn): - return torch_fn(x, ilist, y) + def advancedincsubtensor(x, y, *ilist, pytorch_fn=pytorch_fn): + return pytorch_fn(x, ilist, y) return advancedincsubtensor -@torch_funcify.register(MakeSlice) -def torch_funcify_MakeSlice(op, **kwargs): +@pytorch_funcify.register(MakeSlice) +def pytorch_funcify_MakeSlice(op, **kwargs): def makeslice(*x): return slice(*x) diff --git a/tests/link/pytorch/test_random.py b/tests/link/pytorch/test_random.py index f7ec310227..bfa7c17139 100644 --- a/tests/link/pytorch/test_random.py +++ b/tests/link/pytorch/test_random.py @@ -17,9 +17,6 @@ pytorch = pytest.importorskip("torch") -from pytensor.link.pytorch.dispatch.random import numpyro_available # noqa: E402 - - def random_function(*args, **kwargs): with pytest.warns( UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" @@ -424,9 +421,6 @@ def test_random_updates_input_storage_order(): (2,), "vonmises", lambda mu, kappa: (kappa, mu), - marks=pytest.mark.skipif( - not numpyro_available, reason="VonMises dispatch requires numpyro" - ), ), ], ) @@ -569,7 +563,6 @@ def test_negative_binomial(): ) -@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro") def test_binomial(): rng = shared(np.random.RandomState(123)) n = np.array([10, 40]) @@ -581,9 +574,6 @@ def test_binomial(): np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1) -@pytest.mark.skipif( - not numpyro_available, reason="BetaBinomial dispatch requires numpyro" -) def test_beta_binomial(): rng = shared(np.random.RandomState(123)) n = np.array([10, 40]) @@ -600,9 +590,6 @@ def test_beta_binomial(): ) -@pytest.mark.skipif( - not numpyro_available, reason="Multinomial dispatch requires numpyro" -) def test_multinomial(): rng = shared(np.random.RandomState(123)) n = np.array([10, 40]) @@ -616,7 +603,6 @@ def test_multinomial(): ) -@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") def test_vonmises_mu_outside_circle(): # Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle # We test that the random draws from the PyTorch dispatch work as expected in these cases diff --git a/tests/link/pytorch/test_subtensor.py b/tests/link/pytorch/test_subtensor.py index dbeb23f958..1bc3b22ea8 
100644 --- a/tests/link/pytorch/test_subtensor.py +++ b/tests/link/pytorch/test_subtensor.py @@ -5,10 +5,6 @@ from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.tensor import subtensor as at_subtensor -from pytensor.tensor.rewriting.pytorch import ( - boolean_indexing_set_or_inc, - boolean_indexing_sum, -) from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -103,18 +99,6 @@ def test_pytorch_Subtensor_boolean_mask_reexpressible(): compare_pytorch_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)]) -def test_boolean_indexing_sum_not_applicable(): - """Test that boolean_indexing_sum does not return an invalid replacement in cases where it doesn't apply.""" - x = at.matrix("x") - out = x[x[:, 0] < 0, :].sum(axis=-1) - fg = FunctionGraph([x], [out]) - assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None - - out = x[x[:, 0] < 0, 0].sum() - fg = FunctionGraph([x], [out]) - assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None - - def test_pytorch_IncSubtensor(): rng = np.random.default_rng(213234) @@ -240,12 +224,3 @@ def test_pytorch_IncSubtensor_boolean_indexing_reexpressible(): assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) out_fg = FunctionGraph([x_at], [out_at]) compare_pytorch_and_py(out_fg, [x_np]) - - -def test_boolean_indexing_set_or_inc_not_applicable(): - """Test that `boolean_indexing_set_or_inc` does not return an invalid replacement in cases where it doesn't apply.""" - x = at.vector("x") - mask = at.as_tensor(x) > 0 - out = at_subtensor.set_subtensor(x[mask], [0, 1, 2]) - fg = FunctionGraph([x], [out]) - assert boolean_indexing_set_or_inc.transform(fg, fg.outputs[0].owner) is None diff --git a/tests/link/pytorch/test_tensor_basic.py b/tests/link/pytorch/test_tensor_basic.py index c5c9eb9a4a..1fd97bd516 100644 --- a/tests/link/pytorch/test_tensor_basic.py +++ b/tests/link/pytorch/test_tensor_basic.py @@ -5,7 +5,6 @@ pytorch = pytest.importorskip("torch") -import torch.errors import pytensor import pytensor.tensor.basic as at @@ -198,10 +197,6 @@ def test_pytorch_split_not_supported(self): a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis) with pytest.warns(UserWarning, match="Split node does not have constant axis."): fn = pytensor.function([a, split_axis], a_splits, mode="PyTorch") - # Same as above, an AttributeError surpasses the `TracerIntegerConversionError` - # Both errors are included for backwards compatibility - with pytest.raises((AttributeError, torch.errors.TracerIntegerConversionError)): - fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0) def test_pytorch_eye():