From 27e25265a31da525287894090ece90e622426d37 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 13 May 2024 16:01:02 +0530 Subject: [PATCH 01/45] Add pytorch support for some basic Ops --- pytensor/compile/mode.py | 16 + pytensor/link/pytorch/dispatch/__init__.py | 18 ++ pytensor/link/pytorch/dispatch/basic.py | 117 ++++++++ pytensor/link/pytorch/dispatch/elemwise.py | 71 +++++ pytensor/link/pytorch/dispatch/scalar.py | 332 +++++++++++++++++++++ pytensor/link/pytorch/linker.py | 88 ++++++ pytorch_sandbox_example.ipynb | 214 +++++++++++++ 7 files changed, 856 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/__init__.py create mode 100644 pytensor/link/pytorch/dispatch/basic.py create mode 100644 pytensor/link/pytorch/dispatch/elemwise.py create mode 100644 pytensor/link/pytorch/dispatch/scalar.py create mode 100644 pytensor/link/pytorch/linker.py create mode 100644 pytorch_sandbox_example.ipynb diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 2e3d8f456e..86bff7b478 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -28,6 +28,7 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.vm import VMLinker @@ -47,6 +48,7 @@ "vm_nogc": VMLinker(allow_gc=False, use_cloop=False), "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True), "jax": JAXLinker(), + "pytorch": PytorchLinker(), "numba": NumbaLinker(), } @@ -462,6 +464,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ], ), ) +PYTORCH = Mode( + PytorchLinker(), + RewriteDatabaseQuery( + include=["fast_run"], + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "local_uint_constant_indices", + ], + ), +) NUMBA = Mode( NumbaLinker(), RewriteDatabaseQuery( @@ -476,6 +491,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "FAST_RUN": FAST_RUN, "JAX": JAX, "NUMBA": NUMBA, + "PYTORCH": PYTORCH, } instantiated_default_mode = None diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py new file mode 100644 index 0000000000..b8a4547dc5 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -0,0 +1,18 @@ +# isort: off +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify + +# # Load dispatch specializations +import pytensor.link.pytorch.dispatch.scalar +# import pytensor.link.jax.dispatch.tensor_basic +# import pytensor.link.jax.dispatch.subtensor +# import pytensor.link.jax.dispatch.shape +# import pytensor.link.jax.dispatch.extra_ops +# import pytensor.link.jax.dispatch.nlinalg +# import pytensor.link.jax.dispatch.slinalg +# import pytensor.link.jax.dispatch.random +import pytensor.link.pytorch.dispatch.elemwise +# import pytensor.link.jax.dispatch.scan +# import pytensor.link.jax.dispatch.sparse +# import pytensor.link.jax.dispatch.blockwise + +# isort: on diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py new file mode 100644 index 0000000000..8e752429f1 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -0,0 +1,117 @@ +import warnings +from functools import singledispatch + +# import jax +# import jax.numpy as jnp +import torch +import numpy as np + +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.ifelse import IfElse +from 
pytensor.link.utils import fgraph_to_python +from pytensor.raise_op import Assert, CheckAndRaise + + +# if config.floatX == "float64": +# jax.config.update("jax_enable_x64", True) +# else: +# jax.config.update("jax_enable_x64", False) + + +@singledispatch +def pytorch_typify(data, dtype=None, **kwargs): + r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" + if dtype is None: + return data + else: + return torch.tensor(data, dtype=dtype) + + +@pytorch_typify.register(np.ndarray) +def pytorch_typify_ndarray(data, dtype=None, **kwargs): + if len(data.shape) == 0: + return data.item() + return torch.tensor(data, dtype=dtype) + + +@singledispatch +def pytorch_funcify(op, node=None, storage_map=None, **kwargs): + """Create a PyTorch compatible function from an PyTensor `Op`.""" + raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}") + + +@pytorch_funcify.register(FunctionGraph) +def pytorch_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="pytorch_funcified_fgraph", + **kwargs, +): + return fgraph_to_python( + fgraph, + pytorch_funcify, + type_conversion_fn=pytorch_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + +@pytorch_funcify.register(IfElse) +def pytorch_funcify_IfElse(op, **kwargs): + n_outs = op.n_outs + + def ifelse(cond, *args, n_outs=n_outs): + res = torch.where( + cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None + ) + return res if n_outs > 1 else res[0] + + return ifelse + + +@pytorch_funcify.register(Assert) +@pytorch_funcify.register(CheckAndRaise) +def pytorch_funcify_CheckAndRaise(op, **kwargs): + warnings.warn( + f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""", + stacklevel=2, + ) + + def assert_fn(x, *inputs): + return x + + return assert_fn + + +def pytorch_safe_copy(x): + try: + res = x.clone().detach() + except NotImplementedError: + # warnings.warn( + # "`jnp.copy` is not implemented yet. Using the object's `copy` method." 
+ # ) + if hasattr(x, "copy"): + res = torch.tensor(x.copy()) + else: + warnings.warn(f"Object has no `copy` method: {x}") + res = x + + return res + + +@pytorch_funcify.register(DeepCopyOp) +def pytorch_funcify_DeepCopyOp(op, **kwargs): + def deepcopyop(x): + return pytorch_safe_copy(x) + + return deepcopyop + + +@pytorch_funcify.register(ViewOp) +def pytorch_funcify_ViewOp(op, **kwargs): + def viewop(x): + return x + + return viewop diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py new file mode 100644 index 0000000000..5ec3fcc10a --- /dev/null +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -0,0 +1,71 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad + + +@pytorch_funcify.register(Elemwise) +def pytorch_funcify_Elemwise(op, node, **kwargs): + scalar_op = op.scalar_op + base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) + + def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, tuple(map(torch.tensor, inputs))) + return base_fn(*inputs) + + return elemwise_fn + + + +@pytorch_funcify.register(DimShuffle) +def pytorch_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + res = torch.transpose(x, *op.transposition) + print(res, '-----') + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + res = torch.reshape(res, shape) + + if not op.inplace: + res = res.clone().detach() + + return res + + return dimshuffle + + +@pytorch_funcify.register(Softmax) +def pytorch_funcify_Softmax(op, **kwargs): + axis = op.axis + + def softmax(x): + x = torch.tensor(x).cuda() + return torch.nn.functional.softmax(x, dim=axis) + + return softmax + + +@pytorch_funcify.register(SoftmaxGrad) +def pytorch_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdims=True) * sm + + return softmax_grad + + +@pytorch_funcify.register(LogSoftmax) +def pytorch_funcify_LogSoftmax(op, **kwargs): + axis = op.axis + + def log_softmax(x): + return torch.nn.functional.log_softmax(x, dim=axis) + + return log_softmax diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py new file mode 100644 index 0000000000..b32a08d9fc --- /dev/null +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -0,0 +1,332 @@ +import functools +import typing +from collections.abc import Callable + +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.scalar import Softplus +from pytensor.scalar.basic import ( + Add, + Cast, + Clip, + Composite, + Identity, + IntDiv, + Mod, + Mul, + ScalarOp, + Second, + Sub, +) +from pytensor.scalar.math import ( + BetaIncInv, + Erf, + Erfc, + Erfcinv, + Erfcx, + Erfinv, + GammaIncCInv, + GammaIncInv, + Iv, + Ive, + Log1mexp, + Psi, + TriGamma, +) + + +def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: str | None = None) -> Callable: + try: + import tensorflow_probability.substrates.jax.math as tfp_jax_math + except ModuleNotFoundError: + raise NotImplementedError( + f"No JAX implementation for Op {op.name}. 
" + "Implementation is available if TensorFlow Probability is installed" + ) + + if jax_op_name is None: + jax_op_name = op.name + return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name)) + + +def all_inputs_are_scalar(node): + """Check whether all the inputs of an `Elemwise` are scalar values. + + """ + ndims_input = [inp.type.ndim for inp in node.inputs] + are_inputs_scalars = True + for ndim in ndims_input: + try: + if ndim > 0: + are_inputs_scalars = False + except TypeError: + are_inputs_scalars = False + + return are_inputs_scalars + + +@pytorch_funcify.register(ScalarOp) +def pytorch_funcify_ScalarOp(op, node, **kwargs): + """Return pytorch function that implements the same computation as the Scalar Op. + + This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does, + even though it's dispatched on the Scalar Op. + """ + + # We dispatch some PyTensor operators to Python operators + # whenever the inputs are all scalars. + if all_inputs_are_scalar(node): + pytorch_func = pytorch_funcify_scalar_op_via_py_operators(op) + if pytorch_func is not None: + return pytorch_func + + nfunc_spec = getattr(op, "nfunc_spec", None) + if nfunc_spec is None: + raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}") + + func_name = nfunc_spec[0] + print + # if "." in func_name: + # pytorch_func = functools.reduce(getattr, [jax, *func_name.split(".")]) + # else: + pytorch_func = getattr(torch, func_name) + + if len(node.inputs) > op.nfunc_spec[1]: + # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function, + # even though the base Op from `func_name` is specified as a binary Op. + # This happens with `Add`, which can work as a `Sum` for multiple scalars. + pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None) + if not pytorch_variadic_func: + raise NotImplementedError( + f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs" + ) + + def pytorch_func(*args): + return pytorch_variadic_func( + torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0 + ) + + return pytorch_func + + +@functools.singledispatch +def pytorch_funcify_scalar_op_via_py_operators(op): + """Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays. + + Scalar (constant) arrays in the JAX backend get lowered to the native types (int, floats), + which can perform better with Python operators, and more importantly, avoid upcasting to array types + not supported by some JAX functions. 
+ """ + return None + + +@pytorch_funcify_scalar_op_via_py_operators.register(Add) +def pytorch_funcify_scalar_Add(op): + def elemwise(*inputs): + return sum(inputs) + + return elemwise + + +@pytorch_funcify_scalar_op_via_py_operators.register(Mul) +def pytorch_funcify_scalar_Mul(op): + import operator + from functools import reduce + + def elemwise(*inputs): + return reduce(operator.mul, inputs, 1) + + return elemwise + + +@pytorch_funcify_scalar_op_via_py_operators.register(Sub) +def pytorch_funcify_scalar_Sub(op): + def elemwise(x, y): + return x - y + + return elemwise + + +@pytorch_funcify_scalar_op_via_py_operators.register(IntDiv) +def pytorch_funcify_scalar_IntDiv(op): + def elemwise(x, y): + return x // y + + return elemwise + + +@pytorch_funcify_scalar_op_via_py_operators.register(Mod) +def pytorch_funcify_scalar_Mod(op): + def elemwise(x, y): + return x % y + + return elemwise + + +@pytorch_funcify.register(Cast) +def pytorch_funcify_Cast(op, **kwargs): + def cast(x): + return torch.tensor(x).astype(op.o_type.dtype) + + return cast + + +@pytorch_funcify.register(Identity) +def pytorch_funcify_Identity(op, **kwargs): + def identity(x): + return x + + return identity + + +@pytorch_funcify.register(Clip) +def pytorch_funcify_Clip(op, **kwargs): + """Register the translation for the `Clip` `Op`. + + PyTensor's `Clip` operator operates differently from NumPy's when the + specified `min` is larger than the `max` so we cannot reuse `jax.numpy.clip` + to maintain consistency with PyTensor. + + """ + + def clip(x, min, max): + return torch.where(x < min, min, torch.where(x > max, max, x)) + + return clip + + +@pytorch_funcify.register(Composite) +def pytorch_funcify_Composite(op, node, vectorize=True, **kwargs): + jax_impl = pytorch_funcify(op.fgraph) + + if len(node.outputs) == 1: + + def composite(*args): + return jax_impl(*args)[0] + + else: + + def composite(*args): + return jax_impl(*args) + + return jnp.vectorize(composite) + + +@pytorch_funcify.register(Second) +def pytorch_funcify_Second(op, **kwargs): + def second(x, y): + _, y = torch.broadcast_tensors(x, y) + return y + + return second + + +@pytorch_funcify.register(GammaIncInv) +def pytorch_funcify_GammaIncInv(op, **kwargs): + gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv") + + return gammaincinv + + +@pytorch_funcify.register(GammaIncCInv) +def pytorch_funcify_GammaIncCInv(op, **kwargs): + gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv") + + return gammainccinv + + +@pytorch_funcify.register(Erf) +def pytorch_funcify_Erf(op, node, **kwargs): + def erf(x): + return torch.special.erf(x) + + return erf + + +@pytorch_funcify.register(Erfc) +def pytorch_funcify_Erfc(op, **kwargs): + def erfc(x): + return torch.special.erfc(x) + + return erfc + + +@pytorch_funcify.register(Erfinv) +def pytorch_funcify_Erfinv(op, **kwargs): + def erfinv(x): + return torch.special.erfinv(x) + + return erfinv + + +@pytorch_funcify.register(BetaIncInv) +@pytorch_funcify.register(Erfcx) +@pytorch_funcify.register(Erfcinv) +def pytorch_funcify_from_tfp(op, **kwargs): + tfp_jax_op = try_import_tfp_jax_op(op) + + return tfp_jax_op + + +@pytorch_funcify.register(Iv) +def pytorch_funcify_Iv(op, **kwargs): + ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") + + def iv(v, x): + return ive(v, x) / torch.exp(-torch.abs(torch.real(x))) + + return iv + + +@pytorch_funcify.register(Ive) +def pytorch_funcify_Ive(op, **kwargs): + ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") + + return ive + + 
+@pytorch_funcify.register(Log1mexp) +def pytorch_funcify_Log1mexp(op, node, **kwargs): + def log1mexp(x): + return torch.where( + x < torch.log(0.5), torch.log1p(-torch.exp(x)), torch.log(-torch.expm1(x)) + ) + + return log1mexp + + +@pytorch_funcify.register(Psi) +def pytorch_funcify_Psi(op, node, **kwargs): + def psi(x): + return torch.special.digamma(x) + + return psi + + +@pytorch_funcify.register(TriGamma) +def pytorch_funcify_TriGamma(op, node, **kwargs): + def tri_gamma(x): + return torch.special.polygamma(1, x) + + return tri_gamma + + +@pytorch_funcify.register(Softplus) +def pytorch_funcify_Softplus(op, **kwargs): + def softplus(x): + return torch.where( + x < -37.0, + torch.exp(x), + torch.where( + x < 18.0, + torch.log1p(torch.exp(x)), + torch.where( + x < 33.3, + x + torch.exp(-x), + x, + ), + ), + ) + + return softplus diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py new file mode 100644 index 0000000000..0965b2e092 --- /dev/null +++ b/pytensor/link/pytorch/linker.py @@ -0,0 +1,88 @@ +import warnings + +from numpy.random import Generator, RandomState + +from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.graph.basic import Constant +from pytensor.link.basic import JITLinker + + +class PytorchLinker(JITLinker): + """A `Linker` that compiles NumPy-based operations using torch.compile.""" + + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): + from pytensor.link.pytorch.dispatch import pytorch_typify, pytorch_funcify + from pytensor.tensor.random.type import RandomType + + shared_rng_inputs = [ + inp + for inp in fgraph.inputs + if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)) + ] + + # Replace any shared RNG inputs so that their values can be updated in place + # without affecting the original RNG container. This is necessary because + # JAX does not accept RandomState/Generators as inputs, and they will have to + # be tipyfied + if shared_rng_inputs: + # warnings.warn( + # f"The RandomType SharedVariables {shared_rng_inputs} will not be used " + # f"in the compiled JAX graph. Instead a copy will be used.", + # UserWarning, + # ) + new_shared_rng_inputs = [ + shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs + ] + + fgraph.replace_all( + zip(shared_rng_inputs, new_shared_rng_inputs), + import_missing=True, + reason="PytorchLinker.fgraph_convert", + ) + + for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs): + new_inp_storage = [new_inp.get_value(borrow=True)] + storage_map[new_inp] = new_inp_storage + old_inp_storage = storage_map.pop(old_inp) + # Find index of old_inp_storage in input_storage + for input_storage_idx, input_storage_item in enumerate(input_storage): + # We have to establish equality based on identity because input_storage may contain numpy arrays + if input_storage_item is old_inp_storage: + break + else: # no break + raise ValueError() + input_storage[input_storage_idx] = new_inp_storage + # We need to change the order of the inputs of the FunctionGraph + # so that the new input is in the same position as to old one, + # to align with the storage_map. We hope this is safe! 
+ old_inp_fgrap_index = fgraph.inputs.index(old_inp) + fgraph.remove_input( + old_inp_fgrap_index, + reason="PytorchLinker.fgraph_convert", + ) + fgraph.inputs.remove(new_inp) + fgraph.inputs.insert(old_inp_fgrap_index, new_inp) + + return pytorch_funcify( + fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs + ) + + def jit_compile(self, fn): + import torch + + return torch.compile(fn) + + def create_thunk_inputs(self, storage_map): + from pytensor.link.pytorch.dispatch import pytorch_typify + + thunk_inputs = [] + for n in self.fgraph.inputs: + sinput = storage_map[n] + if isinstance(sinput[0], RandomState | Generator): + new_value = pytorch_typify( + sinput[0], dtype=getattr(sinput[0], "dtype", None) + ) + sinput[0] = new_value + thunk_inputs.append(sinput) + + return thunk_inputs diff --git a/pytorch_sandbox_example.ipynb b/pytorch_sandbox_example.ipynb new file mode 100644 index 0000000000..1f14010432 --- /dev/null +++ b/pytorch_sandbox_example.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Log [id A] 2\n", + " └─ Sub [id B] 1\n", + " ├─ ExpandDims{axis=0} [id C] 0\n", + " │ └─ 1 [id D]\n", + " └─ x [id E]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "import torch\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import numpy as np\n", + "\n", + "from pytensor.graph.fg import FunctionGraph\n", + "from pytensor.link.jax.dispatch import jax_funcify\n", + "from pytensor.link.pytorch.dispatch import pytorch_funcify\n", + "from pytensor.compile.mode import get_mode\n", + "\n", + "from pytensor.graph.rewriting.utils import rewrite_graph\n", + "\n", + "x = pt.vector(\"x\")\n", + "one_mx = 1 - x\n", + "out = pt.log(one_mx)\n", + "\n", + "fg = FunctionGraph(inputs=None, outputs=[out])\n", + "\n", + "pytensor.dprint(fg)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Log1p [id A] 1\n", + " └─ Neg [id B] 0\n", + " └─ x [id C]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "opt_fg = rewrite_graph(fg, include=(\"canonicalize\", \"stabilize\", \"specialize\"))\n", + "pytensor.dprint(opt_fg)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "pytorch_fn = pytorch_funcify(opt_fg)\n", + "jax_fn = jax_funcify(opt_fg)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/_dynamo/utils.py:1764: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " return node.target(*args, **kwargs)\n", + "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/interpreter.py:274: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + 
" return target(*args, **kwargs)\n", + "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/interpreter.py:274: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " return target(*args, **kwargs)\n", + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX output = [-2.30258509 -2.30258509]\n", + "Pytorch output = tensor([-2.3026, -2.3026], device='cuda:0')\n" + ] + } + ], + "source": [ + "pytorch_compiled_fn = torch.compile(pytorch_fn)\n", + "pytorch_out = pytorch_compiled_fn(torch.tensor([0.9, 0.9]).cuda())[0]\n", + "\n", + "jax_compiled_fn = jax.jit(jax_fn)\n", + "jax_out = jax_compiled_fn(np.array([0.9, 0.9]))[0]\n", + "\n", + "print(f'JAX output = {jax_out}')\n", + "print(f'Pytorch output = {pytorch_out}')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0;31mSignature:\u001b[0m \u001b[0mpytorch_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m \n", + "\u001b[0;31mSource:\u001b[0m \n", + "\u001b[0;32mdef\u001b[0m \u001b[0mpytorch_funcified_fgraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# Neg(x)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# Log1p(Neg.0)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_variable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor_variable_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m /tmp/tmplb4e259q\n", + "\u001b[0;31mType:\u001b[0m function" + ] + } + ], + "source": [ + "??pytorch_fn" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0;31mSignature:\u001b[0m \u001b[0mjax_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m \n", + "\u001b[0;31mSource:\u001b[0m \n", + "\u001b[0;32mdef\u001b[0m \u001b[0mjax_funcified_fgraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# Neg(x)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# Log1p(Neg.0)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m 
\u001b[0mtensor_variable_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_variable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor_variable_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m /tmp/tmpy7iyurim\n", + "\u001b[0;31mType:\u001b[0m function" + ] + } + ], + "source": [ + "??jax_fn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytensor-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 629d00bf26416934bfd0c35736efe9be780fb53a Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 13 May 2024 16:23:11 +0530 Subject: [PATCH 02/45] update variable names, docstrings --- pytensor/link/pytorch/dispatch/basic.py | 4 --- pytensor/link/pytorch/dispatch/scalar.py | 24 ++++++------- pytorch_sandbox_example.ipynb | 43 ++++++++++++------------ 3 files changed, 31 insertions(+), 40 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 8e752429f1..162376ee22 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -74,10 +74,6 @@ def ifelse(cond, *args, n_outs=n_outs): @pytorch_funcify.register(Assert) @pytorch_funcify.register(CheckAndRaise) def pytorch_funcify_CheckAndRaise(op, **kwargs): - warnings.warn( - f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""", - stacklevel=2, - ) def assert_fn(x, *inputs): return x diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index b32a08d9fc..8acdb3aae9 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -184,10 +184,6 @@ def identity(x): def pytorch_funcify_Clip(op, **kwargs): """Register the translation for the `Clip` `Op`. - PyTensor's `Clip` operator operates differently from NumPy's when the - specified `min` is larger than the `max` so we cannot reuse `jax.numpy.clip` - to maintain consistency with PyTensor. 
- """ def clip(x, min, max): @@ -196,21 +192,21 @@ def clip(x, min, max): return clip -@pytorch_funcify.register(Composite) -def pytorch_funcify_Composite(op, node, vectorize=True, **kwargs): - jax_impl = pytorch_funcify(op.fgraph) +# @pytorch_funcify.register(Composite) +# def pytorch_funcify_Composite(op, node, vectorize=True, **kwargs): +# jax_impl = pytorch_funcify(op.fgraph) - if len(node.outputs) == 1: +# if len(node.outputs) == 1: - def composite(*args): - return jax_impl(*args)[0] +# def composite(*args): +# return jax_impl(*args)[0] - else: +# else: - def composite(*args): - return jax_impl(*args) +# def composite(*args): +# return jax_impl(*args) - return jnp.vectorize(composite) +# return jnp.vectorize(composite) @pytorch_funcify.register(Second) diff --git a/pytorch_sandbox_example.ipynb b/pytorch_sandbox_example.ipynb index 1f14010432..011c69f15d 100644 --- a/pytorch_sandbox_example.ipynb +++ b/pytorch_sandbox_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -19,10 +19,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 1, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -67,10 +67,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -92,9 +92,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX output = [-2.30258509 -2.30258509]\n", + "Pytorch output = tensor([-2.3026, -2.3026], device='cuda:0')\n" + ] + }, { "name": "stderr", "output_type": "stream", @@ -104,16 +112,7 @@ "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/interpreter.py:274: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " return target(*args, **kwargs)\n", "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/interpreter.py:274: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " return target(*args, **kwargs)\n", - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "JAX output = [-2.30258509 -2.30258509]\n", - "Pytorch output = tensor([-2.3026, -2.3026], device='cuda:0')\n" + " return target(*args, **kwargs)\n" ] } ], @@ -130,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -146,7 +145,7 @@ "\u001b[0;34m\u001b[0m \u001b[0;31m# Log1p(Neg.0)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_variable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor_variable_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFile:\u001b[0m /tmp/tmplb4e259q\n", + "\u001b[0;31mFile:\u001b[0m /tmp/tmpa33e4b_h\n", "\u001b[0;31mType:\u001b[0m function" ] } @@ -157,7 +156,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -173,7 +172,7 @@ "\u001b[0;34m\u001b[0m \u001b[0;31m# Log1p(Neg.0)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_variable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor_variable_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFile:\u001b[0m /tmp/tmpy7iyurim\n", + "\u001b[0;31mFile:\u001b[0m /tmp/tmpldlcl44u\n", "\u001b[0;31mType:\u001b[0m function" ] } From 3eceb568431da206f78c00b8ca11a5731c734568 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 18 May 2024 01:14:18 +0530 Subject: [PATCH 03/45] Avoid numpy conversion of torch Tensors --- pytensor/link/pytorch/linker.py | 5 +---- pytensor/tensor/basic.py | 6 ++++++ pytensor/tensor/type.py | 5 +++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 0965b2e092..605c930d1b 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,9 +1,6 @@ -import warnings - from numpy.random import Generator, RandomState from pytensor.compile.sharedvalue import SharedVariable, shared -from pytensor.graph.basic import Constant from pytensor.link.basic import JITLinker @@ -11,7 +8,7 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): - from pytensor.link.pytorch.dispatch import pytorch_typify, pytorch_funcify + from pytensor.link.pytorch.dispatch import pytorch_funcify from pytensor.tensor.random.type import RandomType shared_rng_inputs = [ diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 518b55da99..dacb52a63a 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -14,6 +14,7 @@ from typing import cast as type_cast import numpy as np +import torch from numpy.core.multiarray import normalize_axis_index from numpy.core.numeric import normalize_axis_tuple @@ -185,6 +186,11 @@ def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs): return constant(x, name=name, ndim=ndim, dtype=dtype) 
+@_as_tensor_variable.register(torch.Tensor) +def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs): + return constant(x, name=name, ndim=ndim, dtype=dtype) + + @_as_tensor_variable.register(bool) def _as_tensor_bool(x, name, ndim, **kwargs): raise TypeError( diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index b55d226471..721e37bd37 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional import numpy as np +import torch import pytensor from pytensor import scalar as ps @@ -160,6 +161,10 @@ def filter(self, data, strict=False, allow_downcast=None): # however, casting it would defeat the purpose of not # loading the whole data into memory pass + + elif isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray) and (data.dtype == self.numpy_dtype): if data.dtype.num != self.numpy_dtype.num: data = _asarray(data, dtype=self.dtype) From 3cde964ac5b579d6709238a5a5d1cc98ba355f91 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 18 May 2024 01:20:24 +0530 Subject: [PATCH 04/45] Fix typify and CheckAndRaise --- pytensor/link/pytorch/dispatch/basic.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 162376ee22..5d89137d90 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -4,22 +4,14 @@ # import jax # import jax.numpy as jnp import torch -import numpy as np from pytensor.compile.ops import DeepCopyOp, ViewOp -from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import Assert, CheckAndRaise -# if config.floatX == "float64": -# jax.config.update("jax_enable_x64", True) -# else: -# jax.config.update("jax_enable_x64", False) - - @singledispatch def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" @@ -29,10 +21,10 @@ def pytorch_typify(data, dtype=None, **kwargs): return torch.tensor(data, dtype=dtype) -@pytorch_typify.register(np.ndarray) -def pytorch_typify_ndarray(data, dtype=None, **kwargs): - if len(data.shape) == 0: - return data.item() +@pytorch_typify.register(torch.Tensor) +def pytorch_typify_tensor(data, dtype=None, **kwargs): + # if len(data.shape) == 0: + # return data.item() return torch.tensor(data, dtype=dtype) @@ -74,8 +66,9 @@ def ifelse(cond, *args, n_outs=n_outs): @pytorch_funcify.register(Assert) @pytorch_funcify.register(CheckAndRaise) def pytorch_funcify_CheckAndRaise(op, **kwargs): - - def assert_fn(x, *inputs): + def assert_fn(x, *conditions): + for cond in conditions: + assert cond.item() return x return assert_fn From c003aa5fdd97670bed6fd24b8edf34159899c8c3 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 18 May 2024 01:27:34 +0530 Subject: [PATCH 05/45] Fix Elemwise Ops --- pytensor/link/pytorch/dispatch/elemwise.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 5ec3fcc10a..9ba186cfa5 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -1,7 +1,7 @@ import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify -from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise 
+from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -11,18 +11,16 @@ def pytorch_funcify_Elemwise(op, node, **kwargs): base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) def elemwise_fn(*inputs): - Elemwise._check_runtime_broadcast(node, tuple(map(torch.tensor, inputs))) + # Elemwise._check_runtime_broadcast(node, tuple(map(torch.tensor, inputs))) return base_fn(*inputs) return elemwise_fn - @pytorch_funcify.register(DimShuffle) def pytorch_funcify_DimShuffle(op, **kwargs): def dimshuffle(x): res = torch.transpose(x, *op.transposition) - print(res, '-----') shape = list(res.shape[: len(op.shuffle)]) @@ -32,7 +30,7 @@ def dimshuffle(x): res = torch.reshape(res, shape) if not op.inplace: - res = res.clone().detach() + res = res.clone() return res @@ -44,7 +42,6 @@ def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis def softmax(x): - x = torch.tensor(x).cuda() return torch.nn.functional.softmax(x, dim=axis) return softmax From 8dc406e4d5620e55c4b548e9f2417aad46473833 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 18 May 2024 02:55:46 +0530 Subject: [PATCH 06/45] Fix Scalar Ops --- pytensor/link/pytorch/dispatch/scalar.py | 97 ++---------------------- 1 file changed, 5 insertions(+), 92 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 8acdb3aae9..293839ae02 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -1,4 +1,3 @@ -import functools import typing from collections.abc import Callable @@ -7,17 +6,11 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar import Softplus from pytensor.scalar.basic import ( - Add, Cast, Clip, - Composite, Identity, - IntDiv, - Mod, - Mul, ScalarOp, Second, - Sub, ) from pytensor.scalar.math import ( BetaIncInv, @@ -51,9 +44,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: str | None = None) -> Calla def all_inputs_are_scalar(node): - """Check whether all the inputs of an `Elemwise` are scalar values. - - """ + """Check whether all the inputs of an `Elemwise` are scalar values.""" ndims_input = [inp.type.ndim for inp in node.inputs] are_inputs_scalars = True for ndim in ndims_input: @@ -74,22 +65,12 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs): even though it's dispatched on the Scalar Op. """ - # We dispatch some PyTensor operators to Python operators - # whenever the inputs are all scalars. - if all_inputs_are_scalar(node): - pytorch_func = pytorch_funcify_scalar_op_via_py_operators(op) - if pytorch_func is not None: - return pytorch_func - nfunc_spec = getattr(op, "nfunc_spec", None) if nfunc_spec is None: raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}") func_name = nfunc_spec[0] - print - # if "." in func_name: - # pytorch_func = functools.reduce(getattr, [jax, *func_name.split(".")]) - # else: + pytorch_func = getattr(torch, func_name) if len(node.inputs) > op.nfunc_spec[1]: @@ -110,60 +91,6 @@ def pytorch_func(*args): return pytorch_func -@functools.singledispatch -def pytorch_funcify_scalar_op_via_py_operators(op): - """Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays. - - Scalar (constant) arrays in the JAX backend get lowered to the native types (int, floats), - which can perform better with Python operators, and more importantly, avoid upcasting to array types - not supported by some JAX functions. 
- """ - return None - - -@pytorch_funcify_scalar_op_via_py_operators.register(Add) -def pytorch_funcify_scalar_Add(op): - def elemwise(*inputs): - return sum(inputs) - - return elemwise - - -@pytorch_funcify_scalar_op_via_py_operators.register(Mul) -def pytorch_funcify_scalar_Mul(op): - import operator - from functools import reduce - - def elemwise(*inputs): - return reduce(operator.mul, inputs, 1) - - return elemwise - - -@pytorch_funcify_scalar_op_via_py_operators.register(Sub) -def pytorch_funcify_scalar_Sub(op): - def elemwise(x, y): - return x - y - - return elemwise - - -@pytorch_funcify_scalar_op_via_py_operators.register(IntDiv) -def pytorch_funcify_scalar_IntDiv(op): - def elemwise(x, y): - return x // y - - return elemwise - - -@pytorch_funcify_scalar_op_via_py_operators.register(Mod) -def pytorch_funcify_scalar_Mod(op): - def elemwise(x, y): - return x % y - - return elemwise - - @pytorch_funcify.register(Cast) def pytorch_funcify_Cast(op, **kwargs): def cast(x): @@ -182,12 +109,10 @@ def identity(x): @pytorch_funcify.register(Clip) def pytorch_funcify_Clip(op, **kwargs): - """Register the translation for the `Clip` `Op`. - - """ + """Register the translation for the `Clip` `Op`.""" def clip(x, min, max): - return torch.where(x < min, min, torch.where(x > max, max, x)) + return torch.clip(x, min, max) return clip @@ -311,18 +236,6 @@ def tri_gamma(x): @pytorch_funcify.register(Softplus) def pytorch_funcify_Softplus(op, **kwargs): def softplus(x): - return torch.where( - x < -37.0, - torch.exp(x), - torch.where( - x < 18.0, - torch.log1p(torch.exp(x)), - torch.where( - x < 33.3, - x + torch.exp(-x), - x, - ), - ), - ) + return torch.nn.functional.softplus(x) return softplus From a8f6ddbf021c390034e9ab3bcaf355b2fc241831 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 18 May 2024 03:11:27 +0530 Subject: [PATCH 07/45] Fix ruff-format --- pytensor/link/pytorch/dispatch/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index b8a4547dc5..e68ca37f84 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -3,6 +3,7 @@ # # Load dispatch specializations import pytensor.link.pytorch.dispatch.scalar + # import pytensor.link.jax.dispatch.tensor_basic # import pytensor.link.jax.dispatch.subtensor # import pytensor.link.jax.dispatch.shape From 9d535f529c67f99fe600b961671b589a3c88ed6f Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 24 May 2024 00:16:18 +0530 Subject: [PATCH 08/45] Initial setup for pytorch tests --- .github/workflows/test.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 18628cfda9..3c5a07b244 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -76,6 +76,7 @@ jobs: float32: [0,1] install-numba: [0] install-jax: [0] + install-torch: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -116,6 +117,11 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/jax" + - install-torch: 1 + python-version: "3.10" + fast-compile: 0 + float32: 0 + # part: "tests/link/pytorch" steps: - uses: actions/checkout@v4 with: @@ -143,6 +149,7 @@ jobs: mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy if [[ $INSTALL_NUMBA == "1" ]]; then mamba install 
--yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi + if [[ $INSTALL_TORCH == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch && pip install tensorflow-probability; fi pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' @@ -151,6 +158,7 @@ jobs: PYTHON_VERSION: ${{ matrix.python-version }} INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} + INSTALL_TORCH: ${{ matrix.install-torch}} - name: Run tests shell: bash -l {0} @@ -195,7 +203,7 @@ jobs: - name: Install dependencies shell: bash -l {0} run: | - mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark + mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytorch pytest-benchmark pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' @@ -264,3 +272,4 @@ jobs: directory: ./coverage/ fail_ci_if_error: true token: ${{ secrets.CODECOV_TOKEN }} + From c5600daae1c95ae877664a3557b06070c9284421 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 24 May 2024 00:32:52 +0530 Subject: [PATCH 09/45] Fix mode parameters for pytorch --- pytensor/compile/mode.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 86bff7b478..fd2ada7947 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -473,7 +473,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "BlasOpt", "fusion", "inplace", - "local_uint_constant_indices", ], ), ) From 54b6248abf1c7eb4a202b3fcedc4c8cca9002b33 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 24 May 2024 01:45:54 +0530 Subject: [PATCH 10/45] Prevent conversion of scalars to numpy --- pytensor/scalar/basic.py | 4 ++++ pytensor/tensor/sharedvar.py | 29 +++++++++++++++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 5d7ba66748..d5b3544cad 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -19,6 +19,7 @@ from typing import Any, TypeAlias import numpy as np +import torch import pytensor from pytensor import printing @@ -242,6 +243,9 @@ def convert(x, dtype=None): if isinstance(x, np.ma.MaskedArray): raise NotImplementedError("MaskedArrays are not supported") + if isinstance(x, torch.Tensor): + return x + if dtype is not None: # in this case, the semantics are that the caller is forcing the dtype x_ = _asarray(x, dtype=dtype) diff --git a/pytensor/tensor/sharedvar.py b/pytensor/tensor/sharedvar.py index dad1751f9b..5ec1c520ce 100644 --- a/pytensor/tensor/sharedvar.py +++ b/pytensor/tensor/sharedvar.py @@ -1,6 +1,7 @@ import warnings import numpy as np +import torch from pytensor.compile import SharedVariable, shared_constructor from pytensor.misc.safe_asarray import _asarray @@ -57,6 +58,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var): return len(var.get_value(borrow=True)) +@shared_constructor.register(torch.Tensor) @shared_constructor.register(np.ndarray) def tensor_constructor( value, @@ -102,6 +104,7 @@ def 
tensor_constructor( ) +@shared_constructor.register(torch.Tensor) @shared_constructor.register(np.number) @shared_constructor.register(float) @shared_constructor.register(int) @@ -128,16 +131,26 @@ def scalar_constructor( dtype = np.asarray(value).dtype dtype = str(dtype) - value = _asarray(value, dtype=dtype) + if not isinstance(value, torch.Tensor): + value = _asarray(value, dtype=dtype) tensor_type = TensorType(dtype=str(value.dtype), shape=()) + if isinstance(value, torch.Tensor): + rval = TensorSharedVariable( + type=tensor_type, + value=value, + name=name, + strict=strict, + allow_downcast=allow_downcast, + ) # Do not pass the dtype to asarray because we want this to fail if # strict is True and the types do not match. - rval = TensorSharedVariable( - type=tensor_type, - value=np.array(value, copy=True), - name=name, - strict=strict, - allow_downcast=allow_downcast, - ) + else: + rval = TensorSharedVariable( + type=tensor_type, + value=np.array(value, copy=True), + name=name, + strict=strict, + allow_downcast=allow_downcast, + ) return rval From 19454b3f200e996b23c4b3d2308c339674922967 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 24 May 2024 01:52:31 +0530 Subject: [PATCH 11/45] Update TensorConstantSignature and map dtypes to Tensor types --- pytensor/tensor/type.py | 7 ++++-- pytensor/tensor/variable.py | 44 ++++++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 721e37bd37..be3e332b7f 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -103,10 +103,13 @@ def __init__( if str(dtype) == "floatX": self.dtype = config.floatX else: - if np.obj2sctype(dtype) is None: + if np.obj2sctype(dtype) is None and "torch" not in str(dtype): raise TypeError(f"Invalid dtype: {dtype}") - self.dtype = np.dtype(dtype).name + if "torch" in str(dtype): + self.dtype = str(dtype).split(".")[-1] + else: + self.dtype = np.dtype(dtype).name def parse_bcast_and_shape(s): if isinstance(s, bool | np.bool_): diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index e881331017..5c5ce1391e 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -6,6 +6,7 @@ from typing import TypeVar import numpy as np +import torch from pytensor import tensor as pt from pytensor.configdefaults import config @@ -959,24 +960,31 @@ def __eq__(self, other): self.no_nan # Ensure has_nan is computed. # Note that in the comparisons below, the elementwise comparisons # come last because they are the most expensive checks. - if self.has_nan: - other.no_nan # Ensure has_nan is computed. - return ( - other.has_nan - and self.sum == other.sum - and (self.no_nan.mask == other.no_nan.mask).all() - and - # Note that the second test below (==) may crash e.g. for - # a single scalar NaN value, so we do not run it when all - # values are missing. - (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) - ) - else: - # Simple case where we do not need to worry about NaN values. - # (note that if there are NaN values in d1, this will return - # False, which is why we do not bother with testing `other.has_nan` - # here). - return (self.sum == other.sum) and np.all(d0 == d1) + + # Check for pytorch tensor + if not isinstance(self[1], torch.Tensor): + if self.has_nan: + other.no_nan # Ensure has_nan is computed. 
+ return ( + other.has_nan + and self.sum == other.sum + and (self.no_nan.mask == other.no_nan.mask).all() + and + # Note that the second test below (==) may crash e.g. for + # a single scalar NaN value, so we do not run it when all + # values are missing. + (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) + ) + else: + # Simple case where we do not need to worry about NaN values. + # (note that if there are NaN values in d1, this will return + # False, which is why we do not bother with testing `other.has_nan` + # here). + return (self.sum == other.sum) and np.all(d0 == d1) + + return (self.sum == other.sum) and torch.all( + d0 == d1 + ) # With pytorch there is no `has_nan` attribute def __ne__(self, other): return not self == other From 92d7114e5695181e89322b70277c086ee53530b4 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 24 May 2024 02:36:33 +0530 Subject: [PATCH 12/45] Add tests for basic ops --- pytensor/link/pytorch/dispatch/basic.py | 13 +- tests/link/pytorch/__init__.py | 0 tests/link/pytorch/test_basic.py | 192 ++++++++++++++++++++++++ 3 files changed, 198 insertions(+), 7 deletions(-) create mode 100644 tests/link/pytorch/__init__.py create mode 100644 tests/link/pytorch/test_basic.py diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 5d89137d90..09b33f7a25 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -1,15 +1,13 @@ import warnings from functools import singledispatch -# import jax -# import jax.numpy as jnp import torch from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse from pytensor.link.utils import fgraph_to_python -from pytensor.raise_op import Assert, CheckAndRaise +from pytensor.raise_op import CheckAndRaise @singledispatch @@ -56,14 +54,15 @@ def pytorch_funcify_IfElse(op, **kwargs): def ifelse(cond, *args, n_outs=n_outs): res = torch.where( - cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None + cond, + args[:n_outs][0], + args[n_outs:][0], ) - return res if n_outs > 1 else res[0] + return res return ifelse -@pytorch_funcify.register(Assert) @pytorch_funcify.register(CheckAndRaise) def pytorch_funcify_CheckAndRaise(op, **kwargs): def assert_fn(x, *conditions): @@ -76,7 +75,7 @@ def assert_fn(x, *conditions): def pytorch_safe_copy(x): try: - res = x.clone().detach() + res = x.clone() except NotImplementedError: # warnings.warn( # "`jnp.copy` is not implemented yet. Using the object's `copy` method." 
diff --git a/tests/link/pytorch/__init__.py b/tests/link/pytorch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py new file mode 100644 index 0000000000..e8d358507a --- /dev/null +++ b/tests/link/pytorch/test_basic.py @@ -0,0 +1,192 @@ +from collections.abc import Callable, Iterable +from functools import partial + +import pytest + +from pytensor.compile.function import function +from pytensor.compile.mode import get_mode +from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.configdefaults import config +from pytensor.graph.basic import Apply +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import Op, get_test_value +from pytensor.ifelse import ifelse +from pytensor.raise_op import assert_op +from pytensor.tensor.type import dscalar, scalar, vector + + +torch = pytest.importorskip("torch") + + +pytorch_mode = get_mode("PYTORCH") +py_mode = get_mode("FAST_COMPILE") + + +def compare_pytorch_and_py( + fgraph: FunctionGraph, + test_inputs: Iterable, + assert_fn: Callable | None = None, + must_be_device_array: bool = True, + pytorch_mode=pytorch_mode, + py_mode=py_mode, +): + """Function to compare python graph output and pytorch compiled output for testing equality + + Parameters + ---------- + fgraph: FunctionGraph + PyTensor function Graph object + test_inputs: iter + Numerical inputs for testing the function graph + assert_fn: func, opt + Assert function used to check for equality between python and pytorch. If not + provided uses torch.testing.assert_close + must_be_device_array: Bool + Checks if torch.device.type is cuda + + + """ + if assert_fn is None: + assert_fn = partial(torch.testing.assert_close) + + fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] + + pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode) + pytorch_res = pytensor_torch_fn(*test_inputs) + + if must_be_device_array: + if isinstance(pytorch_res, list): + assert all(isinstance(res, torch.Tensor) for res in pytorch_res) + else: + assert pytorch_res.device.type == "cuda" + + pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) + + if len(fgraph.outputs) > 1: + for j, p in zip(pytorch_res, py_res): + assert_fn(j, p) + else: + assert_fn(pytorch_res, py_res) + + return pytensor_torch_fn, pytorch_res + + +def test_pytorch_FunctionGraph_once(): + """Make sure that an output is only computed once when it's referenced multiple times.""" + from pytensor.link.pytorch.dispatch import pytorch_funcify + + x = vector("x") + y = vector("y") + + class TestOp(Op): + def __init__(self): + self.called = 0 + + def make_node(self, *args): + return Apply(self, list(args), [x.type() for x in args]) + + def perform(self, inputs, outputs): + for i, inp in enumerate(inputs): + outputs[i][0] = inp[0] + + @pytorch_funcify.register(TestOp) + def pytorch_funcify_TestOp(op, **kwargs): + def func(*args, op=op): + op.called += 1 + return list(args) + + return func + + op1 = TestOp() + op2 = TestOp() + + q, r = op1(x, y) + outs = op2(q + r, q + r) + + out_fg = FunctionGraph([x, y], outs, clone=False) + assert len(out_fg.outputs) == 2 + + out_jx = pytorch_funcify(out_fg) + + x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) + y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) + + res = out_jx(x_val, y_val) + assert len(res) == 2 + assert op1.called == 1 + assert op2.called == 1 + + res = 
out_jx(x_val, y_val) + assert len(res) == 2 + assert op1.called == 2 + assert op2.called == 2 + + +def test_shared(): + a = shared(torch.tensor([1, 2, 3], dtype=getattr(torch, config.floatX))) + pytensor_torch_fn = function([], a, mode="PYTORCH") + pytorch_res = pytensor_torch_fn() + + assert isinstance(pytorch_res, torch.Tensor) + torch.testing.assert_close(pytorch_res, a.get_value()) + + pytensor_torch_fn = function([], a * 2, mode="PYTORCH") + pytorch_res = pytensor_torch_fn() + + assert isinstance(pytorch_res, torch.Tensor) + torch.testing.assert_close(pytorch_res, a.get_value() * 2) + + new_a_value = torch.tensor([3, 4, 5], dtype=getattr(torch, config.floatX)) + a.set_value(new_a_value) + + pytorch_res = pytensor_torch_fn() + assert isinstance(pytorch_res, torch.Tensor) + torch.testing.assert_close(pytorch_res, new_a_value * 2) + + +def test_shared_updates(): + a = shared(0) + + pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH") + res1, res2 = pytensor_torch_fn(), pytensor_torch_fn() + assert res1 == 0 + assert res2 == 1 + assert a.get_value() == 2 + + a.set_value(5) + res1, res2 = pytensor_torch_fn(), pytensor_torch_fn() + assert res1 == 5 + assert res2 == 6 + assert a.get_value() == 7 + + +def test_pytorch_ifelse(): + true_vals = torch.tensor([1, 2, 3]) + false_vals = torch.tensor([-1, -2, -3]) + + x = ifelse(torch.tensor(True), true_vals, false_vals) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + a = dscalar("a") + a.tag.test_value = 0.2 + x = ifelse(a < 0.5, true_vals, false_vals) + x_fg = FunctionGraph([a], [x]) + + compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) + + +def test_pytorch_checkandraise(): + p = scalar() + p.tag.test_value = 0 + + res = assert_op(p, p < 1.0) + + function((p,), res, mode=pytorch_mode) + + +def set_test_value(x, v): + x.tag.test_value = v + return x From 5aae0e5e1239a01e62fa275ea6e8822e54c7196d Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 29 May 2024 15:21:33 +0530 Subject: [PATCH 13/45] Remove torch from user facing API --- pytensor/scalar/basic.py | 4 ---- pytensor/tensor/sharedvar.py | 31 +++++++------------------- pytensor/tensor/type.py | 11 ++------- pytensor/tensor/variable.py | 43 +++++++++++++++--------------------- 4 files changed, 28 insertions(+), 61 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d5b3544cad..5d7ba66748 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -19,7 +19,6 @@ from typing import Any, TypeAlias import numpy as np -import torch import pytensor from pytensor import printing @@ -243,9 +242,6 @@ def convert(x, dtype=None): if isinstance(x, np.ma.MaskedArray): raise NotImplementedError("MaskedArrays are not supported") - if isinstance(x, torch.Tensor): - return x - if dtype is not None: # in this case, the semantics are that the caller is forcing the dtype x_ = _asarray(x, dtype=dtype) diff --git a/pytensor/tensor/sharedvar.py b/pytensor/tensor/sharedvar.py index 5ec1c520ce..bb8a62cf2f 100644 --- a/pytensor/tensor/sharedvar.py +++ b/pytensor/tensor/sharedvar.py @@ -1,7 +1,6 @@ import warnings import numpy as np -import torch from pytensor.compile import SharedVariable, shared_constructor from pytensor.misc.safe_asarray import _asarray @@ -58,7 +57,6 @@ def _get_vector_length_TensorSharedVariable(var_inst, var): return len(var.get_value(borrow=True)) -@shared_constructor.register(torch.Tensor) @shared_constructor.register(np.ndarray) def tensor_constructor( value, @@ -104,7 +102,6 @@ def 
tensor_constructor( ) -@shared_constructor.register(torch.Tensor) @shared_constructor.register(np.number) @shared_constructor.register(float) @shared_constructor.register(int) @@ -131,26 +128,14 @@ def scalar_constructor( dtype = np.asarray(value).dtype dtype = str(dtype) - if not isinstance(value, torch.Tensor): - value = _asarray(value, dtype=dtype) + value = _asarray(value, dtype=dtype) tensor_type = TensorType(dtype=str(value.dtype), shape=()) - if isinstance(value, torch.Tensor): - rval = TensorSharedVariable( - type=tensor_type, - value=value, - name=name, - strict=strict, - allow_downcast=allow_downcast, - ) - # Do not pass the dtype to asarray because we want this to fail if - # strict is True and the types do not match. - else: - rval = TensorSharedVariable( - type=tensor_type, - value=np.array(value, copy=True), - name=name, - strict=strict, - allow_downcast=allow_downcast, - ) + rval = TensorSharedVariable( + type=tensor_type, + value=np.array(value, copy=True), + name=name, + strict=strict, + allow_downcast=allow_downcast, + ) return rval diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index be3e332b7f..cea1f24216 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Literal, Optional import numpy as np -import torch import pytensor from pytensor import scalar as ps @@ -103,13 +102,10 @@ def __init__( if str(dtype) == "floatX": self.dtype = config.floatX else: - if np.obj2sctype(dtype) is None and "torch" not in str(dtype): + if np.obj2sctype(dtype) is None: raise TypeError(f"Invalid dtype: {dtype}") - if "torch" in str(dtype): - self.dtype = str(dtype).split(".")[-1] - else: - self.dtype = np.dtype(dtype).name + self.dtype = np.dtype(dtype).name def parse_bcast_and_shape(s): if isinstance(s, bool | np.bool_): @@ -165,9 +161,6 @@ def filter(self, data, strict=False, allow_downcast=None): # loading the whole data into memory pass - elif isinstance(data, torch.Tensor): - return data - elif isinstance(data, np.ndarray) and (data.dtype == self.numpy_dtype): if data.dtype.num != self.numpy_dtype.num: data = _asarray(data, dtype=self.dtype) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 5c5ce1391e..00fad89c7f 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -6,7 +6,6 @@ from typing import TypeVar import numpy as np -import torch from pytensor import tensor as pt from pytensor.configdefaults import config @@ -961,30 +960,24 @@ def __eq__(self, other): # Note that in the comparisons below, the elementwise comparisons # come last because they are the most expensive checks. - # Check for pytorch tensor - if not isinstance(self[1], torch.Tensor): - if self.has_nan: - other.no_nan # Ensure has_nan is computed. - return ( - other.has_nan - and self.sum == other.sum - and (self.no_nan.mask == other.no_nan.mask).all() - and - # Note that the second test below (==) may crash e.g. for - # a single scalar NaN value, so we do not run it when all - # values are missing. - (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) - ) - else: - # Simple case where we do not need to worry about NaN values. - # (note that if there are NaN values in d1, this will return - # False, which is why we do not bother with testing `other.has_nan` - # here). 
- return (self.sum == other.sum) and np.all(d0 == d1) - - return (self.sum == other.sum) and torch.all( - d0 == d1 - ) # With pytorch there is no `has_nan` attribute + if self.has_nan: + other.no_nan # Ensure has_nan is computed. + return ( + other.has_nan + and self.sum == other.sum + and (self.no_nan.mask == other.no_nan.mask).all() + and + # Note that the second test below (==) may crash e.g. for + # a single scalar NaN value, so we do not run it when all + # values are missing. + (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) + ) + else: + # Simple case where we do not need to worry about NaN values. + # (note that if there are NaN values in d1, this will return + # False, which is why we do not bother with testing `other.has_nan` + # here). + return (self.sum == other.sum) and np.all(d0 == d1) def __ne__(self, other): return not self == other From 8c174dd01604a7d47b6732bc995753d6bf3b746c Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 29 May 2024 15:48:58 +0530 Subject: [PATCH 14/45] Add function to convert numpy arrays to pytorch tensors --- pytensor/link/basic.py | 6 +++++- pytensor/link/pytorch/linker.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index b8784067f3..7500814715 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -600,6 +600,10 @@ def create_thunk_inputs(self, storage_map: dict[Variable, list[Any]]) -> list[An def jit_compile(self, fn: Callable) -> Callable: """JIT compile a converted ``FunctionGraph``.""" + def input_filter(self, inp: Any) -> Any: + """Apply a filter to the data input.""" + return inp + def output_filter(self, var: Variable, out: Any) -> Any: """Apply a filter to the data output by a JITed function call.""" return out @@ -657,7 +661,7 @@ def thunk( thunk_inputs=thunk_inputs, thunk_outputs=thunk_outputs, ): - outputs = fgraph_jit(*[x[0] for x in thunk_inputs]) + outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs]) for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs): compute_map[o_var][0] = True diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 605c930d1b..d37267e5cb 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,3 +1,6 @@ +from typing import Any + +import torch from numpy.random import Generator, RandomState from pytensor.compile.sharedvalue import SharedVariable, shared @@ -7,6 +10,13 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" + def input_filter(self, inp: Any) -> Any: + from pytensor.link.pytorch.dispatch import pytorch_typify + + if isinstance(inp, torch.Tensor): + return inp + return pytorch_typify(inp) + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.pytorch.dispatch import pytorch_funcify from pytensor.tensor.random.type import RandomType From 0977c3ae08a69877b79173a574d345eca9bfc393 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 29 May 2024 15:55:32 +0530 Subject: [PATCH 15/45] Avoid copy when converting to tensor --- pytensor/link/pytorch/dispatch/basic.py | 27 +++++++--------------- pytensor/link/pytorch/dispatch/elemwise.py | 1 - 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 09b33f7a25..dcdd3abfc3 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ 
b/pytensor/link/pytorch/dispatch/basic.py @@ -16,14 +16,7 @@ def pytorch_typify(data, dtype=None, **kwargs): if dtype is None: return data else: - return torch.tensor(data, dtype=dtype) - - -@pytorch_typify.register(torch.Tensor) -def pytorch_typify_tensor(data, dtype=None, **kwargs): - # if len(data.shape) == 0: - # return data.item() - return torch.tensor(data, dtype=dtype) + return torch.as_tensor(data, dtype=dtype) @singledispatch @@ -74,17 +67,13 @@ def assert_fn(x, *conditions): def pytorch_safe_copy(x): - try: - res = x.clone() - except NotImplementedError: - # warnings.warn( - # "`jnp.copy` is not implemented yet. Using the object's `copy` method." - # ) - if hasattr(x, "copy"): - res = torch.tensor(x.copy()) - else: - warnings.warn(f"Object has no `copy` method: {x}") - res = x + # Cannot use try-except due to: https://github.com/pytorch/pytorch/issues/93720 + + if hasattr(x, "clone"): + res = torch.clone(x) + else: + warnings.warn(f"Object has no `clone` method: {x}") + res = x return res diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 9ba186cfa5..dccf4d040b 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -11,7 +11,6 @@ def pytorch_funcify_Elemwise(op, node, **kwargs): base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) def elemwise_fn(*inputs): - # Elemwise._check_runtime_broadcast(node, tuple(map(torch.tensor, inputs))) return base_fn(*inputs) return elemwise_fn From 1c23825c9bc02a8ab66ffdf75ee01dcf52584df1 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 29 May 2024 19:38:25 +0530 Subject: [PATCH 16/45] Fix tests --- pytensor/link/pytorch/dispatch/basic.py | 2 +- tests/link/pytorch/test_basic.py | 27 +++++++++++++------------ 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index dcdd3abfc3..c2cbe80492 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -14,7 +14,7 @@ def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" if dtype is None: - return data + return torch.tensor(data) else: return torch.as_tensor(data, dtype=dtype) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index e8d358507a..cea4b534e8 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -1,6 +1,7 @@ from collections.abc import Callable, Iterable from functools import partial +import numpy as np import pytest from pytensor.compile.function import function @@ -40,14 +41,14 @@ def compare_pytorch_and_py( Numerical inputs for testing the function graph assert_fn: func, opt Assert function used to check for equality between python and pytorch. 
If not - provided uses torch.testing.assert_close + provided uses np.testing.assert_allclose must_be_device_array: Bool Checks if torch.device.type is cuda """ if assert_fn is None: - assert_fn = partial(torch.testing.assert_close) + assert_fn = partial(np.testing.assert_allclose) fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] @@ -107,42 +108,42 @@ def func(*args, op=op): out_fg = FunctionGraph([x, y], outs, clone=False) assert len(out_fg.outputs) == 2 - out_jx = pytorch_funcify(out_fg) + out_torch = pytorch_funcify(out_fg) x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) - res = out_jx(x_val, y_val) + res = out_torch(x_val, y_val) assert len(res) == 2 assert op1.called == 1 assert op2.called == 1 - res = out_jx(x_val, y_val) + res = out_torch(x_val, y_val) assert len(res) == 2 assert op1.called == 2 assert op2.called == 2 def test_shared(): - a = shared(torch.tensor([1, 2, 3], dtype=getattr(torch, config.floatX))) + a = shared(np.array([1, 2, 3], dtype=config.floatX)) pytensor_torch_fn = function([], a, mode="PYTORCH") pytorch_res = pytensor_torch_fn() assert isinstance(pytorch_res, torch.Tensor) - torch.testing.assert_close(pytorch_res, a.get_value()) + np.testing.assert_allclose(pytorch_res, a.get_value()) pytensor_torch_fn = function([], a * 2, mode="PYTORCH") pytorch_res = pytensor_torch_fn() assert isinstance(pytorch_res, torch.Tensor) - torch.testing.assert_close(pytorch_res, a.get_value() * 2) + np.testing.assert_allclose(pytorch_res, a.get_value() * 2) - new_a_value = torch.tensor([3, 4, 5], dtype=getattr(torch, config.floatX)) + new_a_value = np.array([3, 4, 5], dtype=config.floatX) a.set_value(new_a_value) pytorch_res = pytensor_torch_fn() assert isinstance(pytorch_res, torch.Tensor) - torch.testing.assert_close(pytorch_res, new_a_value * 2) + np.testing.assert_allclose(pytorch_res, new_a_value * 2) def test_shared_updates(): @@ -162,10 +163,10 @@ def test_shared_updates(): def test_pytorch_ifelse(): - true_vals = torch.tensor([1, 2, 3]) - false_vals = torch.tensor([-1, -2, -3]) + true_vals = np.r_[1, 2, 3] + false_vals = np.r_[-1, -2, -3] - x = ifelse(torch.tensor(True), true_vals, false_vals) + x = ifelse(np.array(True), true_vals, false_vals) x_fg = FunctionGraph([], [x]) compare_pytorch_and_py(x_fg, []) From c9195a86f48abf6280870eed6fe7b057fa46b266 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 31 May 2024 22:08:32 +0530 Subject: [PATCH 17/45] Remove dispatches that are not tested --- pytensor/link/pytorch/dispatch/basic.py | 10 +- pytensor/link/pytorch/dispatch/elemwise.py | 32 ---- pytensor/link/pytorch/dispatch/scalar.py | 201 --------------------- pytensor/link/pytorch/linker.py | 3 +- pytensor/tensor/basic.py | 2 - pytensor/tensor/sharedvar.py | 2 + 6 files changed, 5 insertions(+), 245 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index c2cbe80492..42d4c500c9 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -3,7 +3,7 @@ import torch -from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse from pytensor.link.utils import fgraph_to_python @@ -84,11 +84,3 @@ def deepcopyop(x): return pytorch_safe_copy(x) return deepcopyop - - -@pytorch_funcify.register(ViewOp) -def pytorch_funcify_ViewOp(op, **kwargs): - def viewop(x): - return 
x - - return viewop diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index dccf4d040b..255a69c933 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,7 +2,6 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @pytorch_funcify.register(Elemwise) @@ -34,34 +33,3 @@ def dimshuffle(x): return res return dimshuffle - - -@pytorch_funcify.register(Softmax) -def pytorch_funcify_Softmax(op, **kwargs): - axis = op.axis - - def softmax(x): - return torch.nn.functional.softmax(x, dim=axis) - - return softmax - - -@pytorch_funcify.register(SoftmaxGrad) -def pytorch_funcify_SoftmaxGrad(op, **kwargs): - axis = op.axis - - def softmax_grad(dy, sm): - dy_times_sm = dy * sm - return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdims=True) * sm - - return softmax_grad - - -@pytorch_funcify.register(LogSoftmax) -def pytorch_funcify_LogSoftmax(op, **kwargs): - axis = op.axis - - def log_softmax(x): - return torch.nn.functional.log_softmax(x, dim=axis) - - return log_softmax diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 293839ae02..56ec438c9f 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -1,62 +1,11 @@ -import typing -from collections.abc import Callable - import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify -from pytensor.scalar import Softplus from pytensor.scalar.basic import ( - Cast, - Clip, - Identity, ScalarOp, - Second, -) -from pytensor.scalar.math import ( - BetaIncInv, - Erf, - Erfc, - Erfcinv, - Erfcx, - Erfinv, - GammaIncCInv, - GammaIncInv, - Iv, - Ive, - Log1mexp, - Psi, - TriGamma, ) -def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: str | None = None) -> Callable: - try: - import tensorflow_probability.substrates.jax.math as tfp_jax_math - except ModuleNotFoundError: - raise NotImplementedError( - f"No JAX implementation for Op {op.name}. " - "Implementation is available if TensorFlow Probability is installed" - ) - - if jax_op_name is None: - jax_op_name = op.name - return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name)) - - -def all_inputs_are_scalar(node): - """Check whether all the inputs of an `Elemwise` are scalar values.""" - ndims_input = [inp.type.ndim for inp in node.inputs] - are_inputs_scalars = True - for ndim in ndims_input: - try: - if ndim > 0: - are_inputs_scalars = False - except TypeError: - are_inputs_scalars = False - - return are_inputs_scalars - - @pytorch_funcify.register(ScalarOp) def pytorch_funcify_ScalarOp(op, node, **kwargs): """Return pytorch function that implements the same computation as the Scalar Op. 
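
The hunk below drops the scalar specializations that are not yet covered by tests. Re-adding one later only needs the usual singledispatch registration against pytorch_funcify; a minimal sketch based on the Erf case removed below:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.math import Erf


@pytorch_funcify.register(Erf)
def pytorch_funcify_Erf(op, node, **kwargs):
    # Mirrors the specialization removed in the hunk that follows.
    def erf(x):
        return torch.special.erf(x)

    return erf
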
@@ -89,153 +38,3 @@ def pytorch_func(*args): ) return pytorch_func - - -@pytorch_funcify.register(Cast) -def pytorch_funcify_Cast(op, **kwargs): - def cast(x): - return torch.tensor(x).astype(op.o_type.dtype) - - return cast - - -@pytorch_funcify.register(Identity) -def pytorch_funcify_Identity(op, **kwargs): - def identity(x): - return x - - return identity - - -@pytorch_funcify.register(Clip) -def pytorch_funcify_Clip(op, **kwargs): - """Register the translation for the `Clip` `Op`.""" - - def clip(x, min, max): - return torch.clip(x, min, max) - - return clip - - -# @pytorch_funcify.register(Composite) -# def pytorch_funcify_Composite(op, node, vectorize=True, **kwargs): -# jax_impl = pytorch_funcify(op.fgraph) - -# if len(node.outputs) == 1: - -# def composite(*args): -# return jax_impl(*args)[0] - -# else: - -# def composite(*args): -# return jax_impl(*args) - -# return jnp.vectorize(composite) - - -@pytorch_funcify.register(Second) -def pytorch_funcify_Second(op, **kwargs): - def second(x, y): - _, y = torch.broadcast_tensors(x, y) - return y - - return second - - -@pytorch_funcify.register(GammaIncInv) -def pytorch_funcify_GammaIncInv(op, **kwargs): - gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv") - - return gammaincinv - - -@pytorch_funcify.register(GammaIncCInv) -def pytorch_funcify_GammaIncCInv(op, **kwargs): - gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv") - - return gammainccinv - - -@pytorch_funcify.register(Erf) -def pytorch_funcify_Erf(op, node, **kwargs): - def erf(x): - return torch.special.erf(x) - - return erf - - -@pytorch_funcify.register(Erfc) -def pytorch_funcify_Erfc(op, **kwargs): - def erfc(x): - return torch.special.erfc(x) - - return erfc - - -@pytorch_funcify.register(Erfinv) -def pytorch_funcify_Erfinv(op, **kwargs): - def erfinv(x): - return torch.special.erfinv(x) - - return erfinv - - -@pytorch_funcify.register(BetaIncInv) -@pytorch_funcify.register(Erfcx) -@pytorch_funcify.register(Erfcinv) -def pytorch_funcify_from_tfp(op, **kwargs): - tfp_jax_op = try_import_tfp_jax_op(op) - - return tfp_jax_op - - -@pytorch_funcify.register(Iv) -def pytorch_funcify_Iv(op, **kwargs): - ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") - - def iv(v, x): - return ive(v, x) / torch.exp(-torch.abs(torch.real(x))) - - return iv - - -@pytorch_funcify.register(Ive) -def pytorch_funcify_Ive(op, **kwargs): - ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") - - return ive - - -@pytorch_funcify.register(Log1mexp) -def pytorch_funcify_Log1mexp(op, node, **kwargs): - def log1mexp(x): - return torch.where( - x < torch.log(0.5), torch.log1p(-torch.exp(x)), torch.log(-torch.expm1(x)) - ) - - return log1mexp - - -@pytorch_funcify.register(Psi) -def pytorch_funcify_Psi(op, node, **kwargs): - def psi(x): - return torch.special.digamma(x) - - return psi - - -@pytorch_funcify.register(TriGamma) -def pytorch_funcify_TriGamma(op, node, **kwargs): - def tri_gamma(x): - return torch.special.polygamma(1, x) - - return tri_gamma - - -@pytorch_funcify.register(Softplus) -def pytorch_funcify_Softplus(op, **kwargs): - def softplus(x): - return torch.nn.functional.softplus(x) - - return softplus diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d37267e5cb..eb18a621b3 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,6 +1,5 @@ from typing import Any -import torch from numpy.random import Generator, RandomState from pytensor.compile.sharedvalue import SharedVariable, shared @@ 
-11,6 +10,8 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" def input_filter(self, inp: Any) -> Any: + import torch + from pytensor.link.pytorch.dispatch import pytorch_typify if isinstance(inp, torch.Tensor): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index dacb52a63a..493d5e8aaf 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -14,7 +14,6 @@ from typing import cast as type_cast import numpy as np -import torch from numpy.core.multiarray import normalize_axis_index from numpy.core.numeric import normalize_axis_tuple @@ -186,7 +185,6 @@ def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs): return constant(x, name=name, ndim=ndim, dtype=dtype) -@_as_tensor_variable.register(torch.Tensor) def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs): return constant(x, name=name, ndim=ndim, dtype=dtype) diff --git a/pytensor/tensor/sharedvar.py b/pytensor/tensor/sharedvar.py index bb8a62cf2f..dad1751f9b 100644 --- a/pytensor/tensor/sharedvar.py +++ b/pytensor/tensor/sharedvar.py @@ -131,6 +131,8 @@ def scalar_constructor( value = _asarray(value, dtype=dtype) tensor_type = TensorType(dtype=str(value.dtype), shape=()) + # Do not pass the dtype to asarray because we want this to fail if + # strict is True and the types do not match. rval = TensorSharedVariable( type=tensor_type, value=np.array(value, copy=True), From b07805cd9151b22d7417ec68579ee4d441c2b29e Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 31 May 2024 22:33:00 +0530 Subject: [PATCH 18/45] set path for pytorch tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3c5a07b244..8885f07a4f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -121,7 +121,7 @@ jobs: python-version: "3.10" fast-compile: 0 float32: 0 - # part: "tests/link/pytorch" + part: "tests/link/pytorch" steps: - uses: actions/checkout@v4 with: From 9e8d3fc8fe947e4105058aa4acf52cb4393c0666 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 4 Jun 2024 23:56:07 +0530 Subject: [PATCH 19/45] Remove tensorflow probability from yml --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8885f07a4f..0135d9d9ce 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -149,7 +149,7 @@ jobs: mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi - if [[ $INSTALL_TORCH == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch && pip install tensorflow-probability; fi + if [[ $INSTALL_TORCH == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch; fi pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' From a2d3afab5c49a544d47615671b1ba88668ee512f Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 5 Jun 2024 00:08:58 +0530 Subject: 
[PATCH 20/45] Add checks for runtime broadcasting --- pytensor/link/pytorch/dispatch/__init__.py | 12 ------------ pytensor/link/pytorch/dispatch/elemwise.py | 1 + 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index e68ca37f84..b6af171995 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -3,17 +3,5 @@ # # Load dispatch specializations import pytensor.link.pytorch.dispatch.scalar - -# import pytensor.link.jax.dispatch.tensor_basic -# import pytensor.link.jax.dispatch.subtensor -# import pytensor.link.jax.dispatch.shape -# import pytensor.link.jax.dispatch.extra_ops -# import pytensor.link.jax.dispatch.nlinalg -# import pytensor.link.jax.dispatch.slinalg -# import pytensor.link.jax.dispatch.random import pytensor.link.pytorch.dispatch.elemwise -# import pytensor.link.jax.dispatch.scan -# import pytensor.link.jax.dispatch.sparse -# import pytensor.link.jax.dispatch.blockwise - # isort: on diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 255a69c933..406eea18ca 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -10,6 +10,7 @@ def pytorch_funcify_Elemwise(op, node, **kwargs): base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, inputs) return base_fn(*inputs) return elemwise_fn From a577a80a2533ff742871aebdae4d30376981bc88 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 5 Jun 2024 00:19:24 +0530 Subject: [PATCH 21/45] Remove IfElse --- pytensor/link/pytorch/dispatch/basic.py | 18 +----------------- tests/link/pytorch/test_basic.py | 22 ++-------------------- 2 files changed, 3 insertions(+), 37 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 42d4c500c9..f38344367c 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -5,7 +5,6 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph -from pytensor.ifelse import IfElse from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise @@ -14,7 +13,7 @@ def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" if dtype is None: - return torch.tensor(data) + return torch.as_tensor(data, dtype=None) else: return torch.as_tensor(data, dtype=dtype) @@ -41,21 +40,6 @@ def pytorch_funcify_FunctionGraph( ) -@pytorch_funcify.register(IfElse) -def pytorch_funcify_IfElse(op, **kwargs): - n_outs = op.n_outs - - def ifelse(cond, *args, n_outs=n_outs): - res = torch.where( - cond, - args[:n_outs][0], - args[n_outs:][0], - ) - return res - - return ifelse - - @pytorch_funcify.register(CheckAndRaise) def pytorch_funcify_CheckAndRaise(op, **kwargs): def assert_fn(x, *conditions): diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index cea4b534e8..73c01c590a 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -10,10 +10,9 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import Op, get_test_value -from pytensor.ifelse import ifelse +from pytensor.graph.op import Op from pytensor.raise_op import assert_op -from 
pytensor.tensor.type import dscalar, scalar, vector +from pytensor.tensor.type import scalar, vector torch = pytest.importorskip("torch") @@ -162,23 +161,6 @@ def test_shared_updates(): assert a.get_value() == 7 -def test_pytorch_ifelse(): - true_vals = np.r_[1, 2, 3] - false_vals = np.r_[-1, -2, -3] - - x = ifelse(np.array(True), true_vals, false_vals) - x_fg = FunctionGraph([], [x]) - - compare_pytorch_and_py(x_fg, []) - - a = dscalar("a") - a.tag.test_value = 0.2 - x = ifelse(a < 0.5, true_vals, false_vals) - x_fg = FunctionGraph([a], [x]) - - compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) - - def test_pytorch_checkandraise(): p = scalar() p.tag.test_value = 0 From 499a174e3dda9167781bebcef3e7e8886803d576 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 12 Jun 2024 10:49:15 +0530 Subject: [PATCH 22/45] Remove dev notebook --- pytorch_sandbox_example.ipynb | 213 ---------------------------------- 1 file changed, 213 deletions(-) delete mode 100644 pytorch_sandbox_example.ipynb diff --git a/pytorch_sandbox_example.ipynb b/pytorch_sandbox_example.ipynb deleted file mode 100644 index 011c69f15d..0000000000 --- a/pytorch_sandbox_example.ipynb +++ /dev/null @@ -1,213 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Log [id A] 2\n", - " └─ Sub [id B] 1\n", - " ├─ ExpandDims{axis=0} [id C] 0\n", - " │ └─ 1 [id D]\n", - " └─ x [id E]\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import torch\n", - "import pytensor\n", - "import pytensor.tensor as pt\n", - "import numpy as np\n", - "\n", - "from pytensor.graph.fg import FunctionGraph\n", - "from pytensor.link.jax.dispatch import jax_funcify\n", - "from pytensor.link.pytorch.dispatch import pytorch_funcify\n", - "from pytensor.compile.mode import get_mode\n", - "\n", - "from pytensor.graph.rewriting.utils import rewrite_graph\n", - "\n", - "x = pt.vector(\"x\")\n", - "one_mx = 1 - x\n", - "out = pt.log(one_mx)\n", - "\n", - "fg = FunctionGraph(inputs=None, outputs=[out])\n", - "\n", - "pytensor.dprint(fg)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Log1p [id A] 1\n", - " └─ Neg [id B] 0\n", - " └─ x [id C]\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "opt_fg = rewrite_graph(fg, include=(\"canonicalize\", \"stabilize\", \"specialize\"))\n", - "pytensor.dprint(opt_fg)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "pytorch_fn = pytorch_funcify(opt_fg)\n", - "jax_fn = jax_funcify(opt_fg)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "JAX output = [-2.30258509 -2.30258509]\n", - "Pytorch output = tensor([-2.3026, -2.3026], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/_dynamo/utils.py:1764: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or 
sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " return node.target(*args, **kwargs)\n", - "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/interpreter.py:274: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " return target(*args, **kwargs)\n", - "/home/harshvir/miniconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/interpreter.py:274: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " return target(*args, **kwargs)\n" - ] - } - ], - "source": [ - "pytorch_compiled_fn = torch.compile(pytorch_fn)\n", - "pytorch_out = pytorch_compiled_fn(torch.tensor([0.9, 0.9]).cuda())[0]\n", - "\n", - "jax_compiled_fn = jax.jit(jax_fn)\n", - "jax_out = jax_compiled_fn(np.array([0.9, 0.9]))[0]\n", - "\n", - "print(f'JAX output = {jax_out}')\n", - "print(f'Pytorch output = {pytorch_out}')" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[0;31mSignature:\u001b[0m \u001b[0mpytorch_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mDocstring:\u001b[0m \n", - "\u001b[0;31mSource:\u001b[0m \n", - "\u001b[0;32mdef\u001b[0m \u001b[0mpytorch_funcified_fgraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;31m# Neg(x)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;31m# Log1p(Neg.0)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_variable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor_variable_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFile:\u001b[0m /tmp/tmpa33e4b_h\n", - "\u001b[0;31mType:\u001b[0m function" - ] - } - ], - "source": [ - "??pytorch_fn" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[0;31mSignature:\u001b[0m \u001b[0mjax_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mDocstring:\u001b[0m \n", - "\u001b[0;31mSource:\u001b[0m \n", - "\u001b[0;32mdef\u001b[0m \u001b[0mjax_funcified_fgraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;31m# Neg(x)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0melemwise_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;31m# Log1p(Neg.0)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mtensor_variable_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise_fn_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_variable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor_variable_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFile:\u001b[0m /tmp/tmpldlcl44u\n", - "\u001b[0;31mType:\u001b[0m function" - ] - } - ], - "source": [ - "??jax_fn" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytensor-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 2826613b199272e79663cd27b126664dbf824e31 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 12 Jun 2024 10:54:25 +0530 Subject: [PATCH 23/45] Fix check and raise --- pytensor/link/pytorch/dispatch/basic.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index f38344367c..d94b82e4b2 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -12,10 +12,8 @@ @singledispatch def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" - if dtype is None: - return torch.as_tensor(data, dtype=None) - else: - return torch.as_tensor(data, dtype=dtype) + device = "cuda" if torch.cuda.is_available() else "cpu" + return torch.as_tensor(data, dtype=dtype, device=device) @singledispatch @@ -42,9 +40,13 @@ def pytorch_funcify_FunctionGraph( @pytorch_funcify.register(CheckAndRaise) def pytorch_funcify_CheckAndRaise(op, **kwargs): + error = op.exc_type + msg = op.msg + def assert_fn(x, *conditions): for cond in conditions: - assert cond.item() + if not cond.item(): + raise error(msg) return x return assert_fn From 62ffceca58e5c04ea18b524bf2a09d563472c1b7 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 12 Jun 2024 11:13:00 +0530 Subject: [PATCH 24/45] Fix compare_pytorch_and_py --- tests/link/pytorch/test_basic.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 73c01c590a..d07790f0ef 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -65,9 +65,9 @@ def compare_pytorch_and_py( if len(fgraph.outputs) > 1: for j, p in zip(pytorch_res, py_res): - assert_fn(j, p) + assert_fn(j.cpu(), p) else: - assert_fn(pytorch_res, py_res) + assert_fn([pytorch_res[0].cpu()], py_res) return pytensor_torch_fn, pytorch_res @@ -129,22 +129,23 @@ def test_shared(): pytorch_res = pytensor_torch_fn() assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res, a.get_value()) + np.testing.assert_allclose(pytorch_res.cpu(), a.get_value()) pytensor_torch_fn = 
function([], a * 2, mode="PYTORCH") pytorch_res = pytensor_torch_fn() assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res, a.get_value() * 2) + np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2) new_a_value = np.array([3, 4, 5], dtype=config.floatX) a.set_value(new_a_value) pytorch_res = pytensor_torch_fn() assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res, new_a_value * 2) + np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2) +@pytest.mark.xfail(reason="Shared variables will be handled in later PRs") def test_shared_updates(): a = shared(0) @@ -168,8 +169,3 @@ def test_pytorch_checkandraise(): res = assert_op(p, p < 1.0) function((p,), res, mode=pytorch_mode) - - -def set_test_value(x, v): - x.tag.test_value = v - return x From acdbba1491fb8cea9fdac92b91351239ab5110d5 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 12 Jun 2024 11:34:42 +0530 Subject: [PATCH 25/45] Fix DimShuffle --- pytensor/link/pytorch/dispatch/elemwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 406eea18ca..f39e108bed 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -19,7 +19,7 @@ def elemwise_fn(*inputs): @pytorch_funcify.register(DimShuffle) def pytorch_funcify_DimShuffle(op, **kwargs): def dimshuffle(x): - res = torch.transpose(x, *op.transposition) + res = torch.permute(x, op.transposition) shape = list(res.shape[: len(op.shuffle)]) From 2519c659fcb53cfbca8b9638c61d6be44f981b53 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Wed, 12 Jun 2024 12:58:56 +0530 Subject: [PATCH 26/45] Add tests for Elemwise operations --- tests/link/pytorch/test_elemwise.py | 55 +++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/link/pytorch/test_elemwise.py diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py new file mode 100644 index 0000000000..1d843b8051 --- /dev/null +++ b/tests/link/pytorch/test_elemwise.py @@ -0,0 +1,55 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor import elemwise as pt_elemwise +from pytensor.tensor.type import matrix, tensor, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_Dimshuffle(): + a_pt = matrix("a") + + x = a_pt.T + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + x = a_pt.dimshuffle([0, 1, "x"]) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + a_pt = tensor(dtype=config.floatX, shape=(None, 1)) + x = a_pt.dimshuffle((0,)) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + a_pt = tensor(dtype=config.floatX, shape=(None, 1)) + x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + +def test_multiple_input_output(): + x = vector("x") + y = vector("y") + out = pt.mul(x, y) + + fg = FunctionGraph(outputs=[out], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5]]) + + x = vector("x") + y = vector("y") + div = pt.int_div(x, y) + pt_sum = pt.add(y, 
x) + + fg = FunctionGraph(outputs=[div, pt_sum], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5]]) + + +def test_pytorch_elemwise(): + x = pt.vector("x") + out = pt.log(1 - x) + + fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(fg, [[0.9, 0.9]]) From eb6d5c2adac7b711cb5814ac78ac04fbbd62f84e Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 14 Jun 2024 18:06:54 +0530 Subject: [PATCH 27/45] Fix test for CheckAndRaise --- tests/link/pytorch/test_basic.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index d07790f0ef..d0a1e52a25 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -11,7 +11,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op -from pytensor.raise_op import assert_op +from pytensor.raise_op import CheckAndRaise from pytensor.tensor.type import scalar, vector @@ -163,9 +163,13 @@ def test_shared_updates(): def test_pytorch_checkandraise(): - p = scalar() - p.tag.test_value = 0 + check_and_raise = CheckAndRaise(AssertionError, "testing") - res = assert_op(p, p < 1.0) + x = scalar("x") + conds = (x > 0, x > 3) + y = check_and_raise(x, *conds) - function((p,), res, mode=pytorch_mode) + y_fn = function([x], y, mode="PYTORCH") + + with pytest.raises(AssertionError, match="testing"): + y_fn(0.0) From 9f02a4fba48188df7d669a7eae193d4cf38f7ded Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 14 Jun 2024 18:14:40 +0530 Subject: [PATCH 28/45] Remove duplicate function --- pytensor/tensor/basic.py | 4 ---- pytensor/tensor/type.py | 1 - pytensor/tensor/variable.py | 1 - 3 files changed, 6 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 493d5e8aaf..518b55da99 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -185,10 +185,6 @@ def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs): return constant(x, name=name, ndim=ndim, dtype=dtype) -def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs): - return constant(x, name=name, ndim=ndim, dtype=dtype) - - @_as_tensor_variable.register(bool) def _as_tensor_bool(x, name, ndim, **kwargs): raise TypeError( diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index cea1f24216..b55d226471 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -160,7 +160,6 @@ def filter(self, data, strict=False, allow_downcast=None): # however, casting it would defeat the purpose of not # loading the whole data into memory pass - elif isinstance(data, np.ndarray) and (data.dtype == self.numpy_dtype): if data.dtype.num != self.numpy_dtype.num: data = _asarray(data, dtype=self.dtype) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 00fad89c7f..e881331017 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -959,7 +959,6 @@ def __eq__(self, other): self.no_nan # Ensure has_nan is computed. # Note that in the comparisons below, the elementwise comparisons # come last because they are the most expensive checks. - if self.has_nan: other.no_nan # Ensure has_nan is computed. 
return ( From caf2965bde9149d28d51a3a731932872c943578b Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 15 Jun 2024 22:00:29 +0530 Subject: [PATCH 29/45] Remove device from pytorch_typify --- pytensor/link/pytorch/dispatch/basic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index d94b82e4b2..53c2c32a7e 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -12,8 +12,7 @@ @singledispatch def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" - device = "cuda" if torch.cuda.is_available() else "cpu" - return torch.as_tensor(data, dtype=dtype, device=device) + return torch.as_tensor(data, dtype=dtype) @singledispatch From c603c6b76635754605f922f1ee94fadafca14397 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 15 Jun 2024 22:56:57 +0530 Subject: [PATCH 30/45] Use micromamba for pytorch install --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8991492fc8..80485a19bd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -152,7 +152,7 @@ jobs: micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi - if [[ $INSTALL_TORCH == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch; fi + if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch; fi pip install -e ./ micromamba list && pip freeze From 3f17107b652ee358e0dccc083823b9ab2ecbeb93 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 16 Jun 2024 15:25:18 +0530 Subject: [PATCH 31/45] Fix pytorch linker --- pytensor/link/pytorch/linker.py | 68 ++------------------------------- 1 file changed, 4 insertions(+), 64 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index eb18a621b3..035d654c83 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,8 +1,6 @@ from typing import Any -from numpy.random import Generator, RandomState - -from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.graph.basic import Variable from pytensor.link.basic import JITLinker @@ -10,66 +8,15 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" def input_filter(self, inp: Any) -> Any: - import torch - from pytensor.link.pytorch.dispatch import pytorch_typify - if isinstance(inp, torch.Tensor): - return inp return pytorch_typify(inp) + def output_filter(self, var: Variable, out: Any) -> Any: + return out.cpu() + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.pytorch.dispatch import pytorch_funcify - from pytensor.tensor.random.type import RandomType - - shared_rng_inputs = [ - inp - for inp in fgraph.inputs - if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)) - ] 
- - # Replace any shared RNG inputs so that their values can be updated in place - # without affecting the original RNG container. This is necessary because - # JAX does not accept RandomState/Generators as inputs, and they will have to - # be tipyfied - if shared_rng_inputs: - # warnings.warn( - # f"The RandomType SharedVariables {shared_rng_inputs} will not be used " - # f"in the compiled JAX graph. Instead a copy will be used.", - # UserWarning, - # ) - new_shared_rng_inputs = [ - shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs - ] - - fgraph.replace_all( - zip(shared_rng_inputs, new_shared_rng_inputs), - import_missing=True, - reason="PytorchLinker.fgraph_convert", - ) - - for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs): - new_inp_storage = [new_inp.get_value(borrow=True)] - storage_map[new_inp] = new_inp_storage - old_inp_storage = storage_map.pop(old_inp) - # Find index of old_inp_storage in input_storage - for input_storage_idx, input_storage_item in enumerate(input_storage): - # We have to establish equality based on identity because input_storage may contain numpy arrays - if input_storage_item is old_inp_storage: - break - else: # no break - raise ValueError() - input_storage[input_storage_idx] = new_inp_storage - # We need to change the order of the inputs of the FunctionGraph - # so that the new input is in the same position as to old one, - # to align with the storage_map. We hope this is safe! - old_inp_fgrap_index = fgraph.inputs.index(old_inp) - fgraph.remove_input( - old_inp_fgrap_index, - reason="PytorchLinker.fgraph_convert", - ) - fgraph.inputs.remove(new_inp) - fgraph.inputs.insert(old_inp_fgrap_index, new_inp) return pytorch_funcify( fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs @@ -81,16 +28,9 @@ def jit_compile(self, fn): return torch.compile(fn) def create_thunk_inputs(self, storage_map): - from pytensor.link.pytorch.dispatch import pytorch_typify - thunk_inputs = [] for n in self.fgraph.inputs: sinput = storage_map[n] - if isinstance(sinput[0], RandomState | Generator): - new_value = pytorch_typify( - sinput[0], dtype=getattr(sinput[0], "dtype", None) - ) - sinput[0] = new_value thunk_inputs.append(sinput) return thunk_inputs From e850d8dea921140d00925dd8dc93dc7ab4c4ec09 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 16 Jun 2024 15:27:38 +0530 Subject: [PATCH 32/45] Fix typify and deepcopy --- pytensor/link/pytorch/dispatch/basic.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 53c2c32a7e..c74df67b5b 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -1,4 +1,3 @@ -import warnings from functools import singledispatch import torch @@ -18,7 +17,9 @@ def pytorch_typify(data, dtype=None, **kwargs): @singledispatch def pytorch_funcify(op, node=None, storage_map=None, **kwargs): """Create a PyTorch compatible function from an PyTensor `Op`.""" - raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}") + raise NotImplementedError( + f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation" + ) @pytorch_funcify.register(FunctionGraph) @@ -51,21 +52,9 @@ def assert_fn(x, *conditions): return assert_fn -def pytorch_safe_copy(x): - # Cannot use try-except due to: 
https://github.com/pytorch/pytorch/issues/93720 - - if hasattr(x, "clone"): - res = torch.clone(x) - else: - warnings.warn(f"Object has no `clone` method: {x}") - res = x - - return res - - @pytorch_funcify.register(DeepCopyOp) def pytorch_funcify_DeepCopyOp(op, **kwargs): def deepcopyop(x): - return pytorch_safe_copy(x) + return x.clone() return deepcopyop From e682fc441267f3d59a07b573eb59d13cac098005 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 16 Jun 2024 16:00:13 +0530 Subject: [PATCH 33/45] Parametrize device in all tests --- tests/link/pytorch/test_basic.py | 152 +++++++++++++++------------- tests/link/pytorch/test_elemwise.py | 82 +++++++++------ 2 files changed, 131 insertions(+), 103 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index d0a1e52a25..d18855fef1 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -72,104 +72,116 @@ def compare_pytorch_and_py( return pytensor_torch_fn, pytorch_res -def test_pytorch_FunctionGraph_once(): +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_pytorch_FunctionGraph_once(device): """Make sure that an output is only computed once when it's referenced multiple times.""" from pytensor.link.pytorch.dispatch import pytorch_funcify - x = vector("x") - y = vector("y") + with torch.device(device): + x = vector("x") + y = vector("y") - class TestOp(Op): - def __init__(self): - self.called = 0 + class TestOp(Op): + def __init__(self): + self.called = 0 - def make_node(self, *args): - return Apply(self, list(args), [x.type() for x in args]) + def make_node(self, *args): + return Apply(self, list(args), [x.type() for x in args]) - def perform(self, inputs, outputs): - for i, inp in enumerate(inputs): - outputs[i][0] = inp[0] + def perform(self, inputs, outputs): + for i, inp in enumerate(inputs): + outputs[i][0] = inp[0] - @pytorch_funcify.register(TestOp) - def pytorch_funcify_TestOp(op, **kwargs): - def func(*args, op=op): - op.called += 1 - return list(args) + @pytorch_funcify.register(TestOp) + def pytorch_funcify_TestOp(op, **kwargs): + def func(*args, op=op): + op.called += 1 + return list(args) - return func + return func - op1 = TestOp() - op2 = TestOp() + op1 = TestOp() + op2 = TestOp() - q, r = op1(x, y) - outs = op2(q + r, q + r) + q, r = op1(x, y) + outs = op2(q + r, q + r) - out_fg = FunctionGraph([x, y], outs, clone=False) - assert len(out_fg.outputs) == 2 + out_fg = FunctionGraph([x, y], outs, clone=False) + assert len(out_fg.outputs) == 2 - out_torch = pytorch_funcify(out_fg) + out_torch = pytorch_funcify(out_fg) - x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) - y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) + x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) + y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) - res = out_torch(x_val, y_val) - assert len(res) == 2 - assert op1.called == 1 - assert op2.called == 1 + res = out_torch(x_val, y_val) + assert len(res) == 2 + assert op1.called == 1 + assert op2.called == 1 - res = out_torch(x_val, y_val) - assert len(res) == 2 - assert op1.called == 2 - assert op2.called == 2 + res = out_torch(x_val, y_val) + assert len(res) == 2 + assert op1.called == 2 + assert op2.called == 2 -def test_shared(): - a = shared(np.array([1, 2, 3], dtype=config.floatX)) - pytensor_torch_fn = function([], a, mode="PYTORCH") - pytorch_res = pytensor_torch_fn() +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def 
test_shared(device): + with torch.device(device): + a = shared(np.array([1, 2, 3], dtype=config.floatX)) + pytensor_torch_fn = function([], a, mode="PYTORCH") + pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res.cpu(), a.get_value()) + assert isinstance(pytorch_res, torch.Tensor) + assert isinstance(a.get_value(), np.ndarray) + np.testing.assert_allclose(pytorch_res.cpu(), a.get_value()) - pytensor_torch_fn = function([], a * 2, mode="PYTORCH") - pytorch_res = pytensor_torch_fn() + pytensor_torch_fn = function([], a * 2, mode="PYTORCH") + pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2) + assert isinstance(pytorch_res, torch.Tensor) + assert isinstance(a.get_value(), np.ndarray) + np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2) - new_a_value = np.array([3, 4, 5], dtype=config.floatX) - a.set_value(new_a_value) + new_a_value = np.array([3, 4, 5], dtype=config.floatX) + a.set_value(new_a_value) - pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2) + pytorch_res = pytensor_torch_fn() + assert isinstance(pytorch_res, torch.Tensor) + np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2) -@pytest.mark.xfail(reason="Shared variables will be handled in later PRs") -def test_shared_updates(): - a = shared(0) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_shared_updates(device): + with torch.device(device): + a = shared(0) - pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH") - res1, res2 = pytensor_torch_fn(), pytensor_torch_fn() - assert res1 == 0 - assert res2 == 1 - assert a.get_value() == 2 + pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH") + res1, res2 = pytensor_torch_fn(), pytensor_torch_fn() + assert res1 == 0 + assert res2 == 1 + assert a.get_value() == 2 + assert isinstance(a.get_value(), np.ndarray) - a.set_value(5) - res1, res2 = pytensor_torch_fn(), pytensor_torch_fn() - assert res1 == 5 - assert res2 == 6 - assert a.get_value() == 7 + a.set_value(5) + res1, res2 = pytensor_torch_fn(), pytensor_torch_fn() + assert res1 == 5 + assert res2 == 6 + assert a.get_value() == 7 + assert isinstance(a.get_value(), np.ndarray) -def test_pytorch_checkandraise(): - check_and_raise = CheckAndRaise(AssertionError, "testing") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_pytorch_checkandraise(device): + with torch.device(device): + check_and_raise = CheckAndRaise(AssertionError, "testing") - x = scalar("x") - conds = (x > 0, x > 3) - y = check_and_raise(x, *conds) + x = scalar("x") + conds = (x > 0, x > 3) + y = check_and_raise(x, *conds) - y_fn = function([x], y, mode="PYTORCH") + y_fn = function([x], y, mode="PYTORCH") - with pytest.raises(AssertionError, match="testing"): - y_fn(0.0) + with pytest.raises(AssertionError, match="testing"): + y_fn(0.0) + assert y_fn(4).item() == 4 diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 1d843b8051..eb5d5f7be7 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -1,4 +1,6 @@ import numpy as np +import pytest +import torch import pytensor.tensor as pt from pytensor.configdefaults import config @@ -8,48 +10,62 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py -def test_pytorch_Dimshuffle(): - a_pt = 
matrix("a") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_pytorch_Dimshuffle(device): + with torch.device(device): + a_pt = matrix("a") - x = a_pt.T - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + x = a_pt.T + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py( + x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] + ) - x = a_pt.dimshuffle([0, 1, "x"]) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + x = a_pt.dimshuffle([0, 1, "x"]) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py( + x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] + ) - a_pt = tensor(dtype=config.floatX, shape=(None, 1)) - x = a_pt.dimshuffle((0,)) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + a_pt = tensor(dtype=config.floatX, shape=(None, 1)) + x = a_pt.dimshuffle((0,)) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py( + x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)] + ) - a_pt = tensor(dtype=config.floatX, shape=(None, 1)) - x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + a_pt = tensor(dtype=config.floatX, shape=(None, 1)) + x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py( + x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)] + ) -def test_multiple_input_output(): - x = vector("x") - y = vector("y") - out = pt.mul(x, y) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_multiple_input_output(device): + with torch.device(device): + x = vector("x") + y = vector("y") + out = pt.mul(x, y) - fg = FunctionGraph(outputs=[out], clone=False) - compare_pytorch_and_py(fg, [[1.5], [2.5]]) + fg = FunctionGraph(outputs=[out], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5]]) - x = vector("x") - y = vector("y") - div = pt.int_div(x, y) - pt_sum = pt.add(y, x) + x = vector("x") + y = vector("y") + div = pt.int_div(x, y) + pt_sum = pt.add(y, x) - fg = FunctionGraph(outputs=[div, pt_sum], clone=False) - compare_pytorch_and_py(fg, [[1.5], [2.5]]) + fg = FunctionGraph(outputs=[div, pt_sum], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5]]) -def test_pytorch_elemwise(): - x = pt.vector("x") - out = pt.log(1 - x) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_pytorch_elemwise(device): + with torch.device(device): + x = pt.vector("x") + out = pt.log(1 - x) - fg = FunctionGraph([x], [out]) - compare_pytorch_and_py(fg, [[0.9, 0.9]]) + fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(fg, [[0.9, 0.9]]) From bf4cf92984cd5f01cce1e97b97509d0449587e4f Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 16 Jun 2024 16:46:46 +0530 Subject: [PATCH 34/45] Install torch with cuda --- .github/workflows/test.yml | 4 ++-- tests/link/pytorch/test_elemwise.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 80485a19bd..023519c268 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -152,7 +152,7 @@ jobs: micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy if [[ $INSTALL_NUMBA == 
"1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi - if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch; fi + if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi pip install -e ./ micromamba list && pip freeze @@ -209,7 +209,7 @@ jobs: - name: Install dependencies shell: micromamba-shell {0} run: | - micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytorch pytest-benchmark + micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia pip install -e ./ micromamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index eb5d5f7be7..d831974d2e 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -1,6 +1,5 @@ import numpy as np import pytest -import torch import pytensor.tensor as pt from pytensor.configdefaults import config @@ -10,6 +9,9 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py +torch = pytest.importorskip("torch") + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_pytorch_Dimshuffle(device): with torch.device(device): From 899e7f930d66ebc947a67e93be5ed6ad1c479e49 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 16 Jun 2024 16:48:22 +0530 Subject: [PATCH 35/45] Fix test_pytorch_FunctionGraph_once --- tests/link/pytorch/test_basic.py | 68 ++++++++++++++++---------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index d18855fef1..22237e342f 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -72,57 +72,55 @@ def compare_pytorch_and_py( return pytensor_torch_fn, pytorch_res -@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_pytorch_FunctionGraph_once(device): """Make sure that an output is only computed once when it's referenced multiple times.""" from pytensor.link.pytorch.dispatch import pytorch_funcify - with torch.device(device): - x = vector("x") - y = vector("y") + x = vector("x") + y = vector("y") - class TestOp(Op): - def __init__(self): - self.called = 0 + class TestOp(Op): + def __init__(self): + self.called = 0 - def make_node(self, *args): - return Apply(self, list(args), [x.type() for x in args]) + def make_node(self, *args): + return Apply(self, list(args), [x.type() for x in args]) - def perform(self, inputs, outputs): - for i, inp in enumerate(inputs): - outputs[i][0] = inp[0] + def perform(self, inputs, outputs): + for i, inp in enumerate(inputs): + outputs[i][0] = inp[0] - @pytorch_funcify.register(TestOp) - def pytorch_funcify_TestOp(op, **kwargs): - def func(*args, op=op): - op.called += 1 - return list(args) + @pytorch_funcify.register(TestOp) + def pytorch_funcify_TestOp(op, **kwargs): + def func(*args, op=op): + op.called += 1 + return 
list(args) - return func + return func - op1 = TestOp() - op2 = TestOp() + op1 = TestOp() + op2 = TestOp() - q, r = op1(x, y) - outs = op2(q + r, q + r) + q, r = op1(x, y) + outs = op2(q + r, q + r) - out_fg = FunctionGraph([x, y], outs, clone=False) - assert len(out_fg.outputs) == 2 + out_fg = FunctionGraph([x, y], outs, clone=False) + assert len(out_fg.outputs) == 2 - out_torch = pytorch_funcify(out_fg) + out_torch = pytorch_funcify(out_fg) - x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) - y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) + x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) + y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) - res = out_torch(x_val, y_val) - assert len(res) == 2 - assert op1.called == 1 - assert op2.called == 1 + res = out_torch(x_val, y_val) + assert len(res) == 2 + assert op1.called == 1 + assert op2.called == 1 - res = out_torch(x_val, y_val) - assert len(res) == 2 - assert op1.called == 2 - assert op2.called == 2 + res = out_torch(x_val, y_val) + assert len(res) == 2 + assert op1.called == 2 + assert op2.called == 2 @pytest.mark.parametrize("device", ["cpu", "cuda"]) From 04d293564bba050d49bd4e8498e9bb729c1c39aa Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 16 Jun 2024 16:50:47 +0530 Subject: [PATCH 36/45] Remove device argument from test --- tests/link/pytorch/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 22237e342f..9d57df666d 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -72,7 +72,7 @@ def compare_pytorch_and_py( return pytensor_torch_fn, pytorch_res -def test_pytorch_FunctionGraph_once(device): +def test_pytorch_FunctionGraph_once(): """Make sure that an output is only computed once when it's referenced multiple times.""" from pytensor.link.pytorch.dispatch import pytorch_funcify From 8ec76617923d45999e3aa88c853b89bcb92f9486 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 17 Jun 2024 17:10:19 +0530 Subject: [PATCH 37/45] remove device from elemwise tests and add assertions --- tests/link/pytorch/test_basic.py | 104 ++++++++++++++++------------ tests/link/pytorch/test_elemwise.py | 84 +++++++++------------- 2 files changed, 92 insertions(+), 96 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 9d57df666d..3d8773eed0 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -72,55 +72,71 @@ def compare_pytorch_and_py( return pytensor_torch_fn, pytorch_res -def test_pytorch_FunctionGraph_once(): +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_pytorch_FunctionGraph_once(device): """Make sure that an output is only computed once when it's referenced multiple times.""" from pytensor.link.pytorch.dispatch import pytorch_funcify - x = vector("x") - y = vector("y") + with torch.device(device): + x = vector("x") + y = vector("y") + + class TestOp(Op): + def __init__(self): + self.called = 0 + + def make_node(self, *args): + return Apply(self, list(args), [x.type() for x in args]) - class TestOp(Op): - def __init__(self): - self.called = 0 + def perform(self, inputs, outputs): + for i, inp in enumerate(inputs): + outputs[i][0] = inp[0] - def make_node(self, *args): - return Apply(self, list(args), [x.type() for x in args]) + @pytorch_funcify.register(TestOp) + def pytorch_funcify_TestOp(op, **kwargs): + def func(*args, op=op): + op.called += 1 
+ for arg in args: + assert arg.device.type == device + return list(args) - def perform(self, inputs, outputs): - for i, inp in enumerate(inputs): - outputs[i][0] = inp[0] + return func - @pytorch_funcify.register(TestOp) - def pytorch_funcify_TestOp(op, **kwargs): - def func(*args, op=op): - op.called += 1 - return list(args) + op1 = TestOp() + op2 = TestOp() - return func + q, r = op1(x, y) + outs = op2(q + r, q + r) - op1 = TestOp() - op2 = TestOp() + out_fg = FunctionGraph([x, y], outs, clone=False) + assert len(out_fg.outputs) == 2 - q, r = op1(x, y) - outs = op2(q + r, q + r) + out_torch = pytorch_funcify(out_fg) - out_fg = FunctionGraph([x, y], outs, clone=False) - assert len(out_fg.outputs) == 2 + x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) + y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) - out_torch = pytorch_funcify(out_fg) + res = out_torch(x_val, y_val) - x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX)) - y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX)) + for output in res: + assert torch.equal( + output, torch.tensor([3, 5]).to(getattr(torch, config.floatX)) + ) - res = out_torch(x_val, y_val) - assert len(res) == 2 - assert op1.called == 1 - assert op2.called == 1 + assert len(res) == 2 + assert op1.called == 1 + assert op2.called == 1 - res = out_torch(x_val, y_val) - assert len(res) == 2 - assert op1.called == 2 - assert op2.called == 2 + res = out_torch(x_val, y_val) + + for output in res: + assert torch.equal( + output, torch.tensor([3, 5]).to(getattr(torch, config.floatX)) + ) + + assert len(res) == 2 + assert op1.called == 2 + assert op2.called == 2 @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -169,17 +185,15 @@ def test_shared_updates(device): assert isinstance(a.get_value(), np.ndarray) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_pytorch_checkandraise(device): - with torch.device(device): - check_and_raise = CheckAndRaise(AssertionError, "testing") +def test_pytorch_checkandraise(): + check_and_raise = CheckAndRaise(AssertionError, "testing") - x = scalar("x") - conds = (x > 0, x > 3) - y = check_and_raise(x, *conds) + x = scalar("x") + conds = (x > 0, x > 3) + y = check_and_raise(x, *conds) - y_fn = function([x], y, mode="PYTORCH") + y_fn = function([x], y, mode="PYTORCH") - with pytest.raises(AssertionError, match="testing"): - y_fn(0.0) - assert y_fn(4).item() == 4 + with pytest.raises(AssertionError, match="testing"): + y_fn(0.0) + assert y_fn(4).item() == 4 diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index d831974d2e..1d843b8051 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -1,5 +1,4 @@ import numpy as np -import pytest import pytensor.tensor as pt from pytensor.configdefaults import config @@ -9,65 +8,48 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py -torch = pytest.importorskip("torch") +def test_pytorch_Dimshuffle(): + a_pt = matrix("a") + x = a_pt.T + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_pytorch_Dimshuffle(device): - with torch.device(device): - a_pt = matrix("a") + x = a_pt.dimshuffle([0, 1, "x"]) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) - x = a_pt.T - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py( - x_fg, [np.c_[[1.0, 
2.0], [3.0, 4.0]].astype(config.floatX)] - ) + a_pt = tensor(dtype=config.floatX, shape=(None, 1)) + x = a_pt.dimshuffle((0,)) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) - x = a_pt.dimshuffle([0, 1, "x"]) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py( - x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] - ) + a_pt = tensor(dtype=config.floatX, shape=(None, 1)) + x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt) + x_fg = FunctionGraph([a_pt], [x]) + compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) - a_pt = tensor(dtype=config.floatX, shape=(None, 1)) - x = a_pt.dimshuffle((0,)) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py( - x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)] - ) - a_pt = tensor(dtype=config.floatX, shape=(None, 1)) - x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py( - x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)] - ) +def test_multiple_input_output(): + x = vector("x") + y = vector("y") + out = pt.mul(x, y) + fg = FunctionGraph(outputs=[out], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5]]) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_multiple_input_output(device): - with torch.device(device): - x = vector("x") - y = vector("y") - out = pt.mul(x, y) + x = vector("x") + y = vector("y") + div = pt.int_div(x, y) + pt_sum = pt.add(y, x) - fg = FunctionGraph(outputs=[out], clone=False) - compare_pytorch_and_py(fg, [[1.5], [2.5]]) + fg = FunctionGraph(outputs=[div, pt_sum], clone=False) + compare_pytorch_and_py(fg, [[1.5], [2.5]]) - x = vector("x") - y = vector("y") - div = pt.int_div(x, y) - pt_sum = pt.add(y, x) - fg = FunctionGraph(outputs=[div, pt_sum], clone=False) - compare_pytorch_and_py(fg, [[1.5], [2.5]]) +def test_pytorch_elemwise(): + x = pt.vector("x") + out = pt.log(1 - x) - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_pytorch_elemwise(device): - with torch.device(device): - x = pt.vector("x") - out = pt.log(1 - x) - - fg = FunctionGraph([x], [out]) - compare_pytorch_and_py(fg, [[0.9, 0.9]]) + fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(fg, [[0.9, 0.9]]) From bb7df4161e529f4e97f6bf7551e119f42086b6ba Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 18 Jun 2024 00:48:10 +0530 Subject: [PATCH 38/45] skip tests if cuda is not available --- tests/link/pytorch/test_basic.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 3d8773eed0..e1367896f0 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -74,6 +74,8 @@ def compare_pytorch_and_py( @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_pytorch_FunctionGraph_once(device): + if torch.cuda.is_available() is False: + pytest.skip("CUDA is not available") """Make sure that an output is only computed once when it's referenced multiple times.""" from pytensor.link.pytorch.dispatch import pytorch_funcify @@ -141,6 +143,8 @@ def func(*args, op=op): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_shared(device): + if torch.cuda.is_available() is False: + pytest.skip("CUDA is not available") with torch.device(device): a = shared(np.array([1, 2, 3], dtype=config.floatX)) pytensor_torch_fn = function([], a, mode="PYTORCH") @@ -167,6 +171,8 @@ def test_shared(device): 
@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_shared_updates(device): + if torch.cuda.is_available() is False: + pytest.skip("CUDA is not available") with torch.device(device): a = shared(0) From 0441cf20d866f67b6dbdf3291b17b9cbb8c1b3f2 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 18 Jun 2024 11:16:32 +0530 Subject: [PATCH 39/45] Fix tests --- tests/link/pytorch/test_basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index e1367896f0..68d937fce8 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -74,7 +74,7 @@ def compare_pytorch_and_py( @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_pytorch_FunctionGraph_once(device): - if torch.cuda.is_available() is False: + if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA is not available") """Make sure that an output is only computed once when it's referenced multiple times.""" from pytensor.link.pytorch.dispatch import pytorch_funcify @@ -143,7 +143,7 @@ def func(*args, op=op): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_shared(device): - if torch.cuda.is_available() is False: + if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA is not available") with torch.device(device): a = shared(np.array([1, 2, 3], dtype=config.floatX)) @@ -171,7 +171,7 @@ def test_shared(device): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_shared_updates(device): - if torch.cuda.is_available() is False: + if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA is not available") with torch.device(device): a = shared(0) From 4ca5acab2b5556c2536fac7d3ddf757ca9b19838 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Sun, 23 Jun 2024 23:03:36 +0100 Subject: [PATCH 40/45] Implemented softmax ops for PyTorch --- environment.yml | 2 +- pytensor/link/pytorch/dispatch/elemwise.py | 16 ++++++++++++++++ tests/link/pytorch/test_elemwise.py | 16 ++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 6be8e376da..9208ccce00 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ channels: dependencies: - python>=3.10 - compilers - - numpy>=1.17.0 + - numpy>=1.17.0,<2 - scipy>=0.14 - filelock - etuples diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index f39e108bed..7416c75f16 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,6 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.special import Softmax @pytorch_funcify.register(Elemwise) @@ -34,3 +35,18 @@ def dimshuffle(x): return res return dimshuffle + + +@pytorch_funcify.register(Softmax) +def pytorch_funcify_Softmax(op, **kwargs): + axis = op.axis + + if axis is None: + raise TypeError( + "Implicit dimension choice for softmax has been deprecated in Pytorch, specify an axis." 
+ ) + + def softmax(x): + return torch.nn.functional.softmax(x, dim=axis) + + return softmax diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 1d843b8051..4a376adf4c 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -1,9 +1,11 @@ import numpy as np +import pytest import pytensor.tensor as pt from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as pt_elemwise +from pytensor.tensor.special import softmax from pytensor.tensor.type import matrix, tensor, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -53,3 +55,17 @@ def test_pytorch_elemwise(): fg = FunctionGraph([x], [out]) compare_pytorch_and_py(fg, [[0.9, 0.9]]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_softmax(axis): + x = matrix("x") + out = softmax(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) + + if axis is None: + with pytest.raises(TypeError): + compare_pytorch_and_py(fgraph, [test_input]) + else: + compare_pytorch_and_py(fgraph, [test_input]) From 287d9c275bbc5ea2e0cf9d3c1dfcae450e85b9bf Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Mon, 24 Jun 2024 21:56:40 +0100 Subject: [PATCH 41/45] Switched to run softmax on all items if axis is None --- pytensor/link/pytorch/dispatch/elemwise.py | 10 ++++------ tests/link/pytorch/test_elemwise.py | 6 +----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 7416c75f16..39592a8378 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -41,12 +41,10 @@ def dimshuffle(x): def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis - if axis is None: - raise TypeError( - "Implicit dimension choice for softmax has been deprecated in Pytorch, specify an axis." 
- ) - def softmax(x): - return torch.nn.functional.softmax(x, dim=axis) + if axis is not None: + return torch.softmax(x, dim=axis) + else: + return torch.softmax(x.ravel(), dim=0).reshape(x.shape) return softmax diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 4a376adf4c..947f2bb4eb 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -64,8 +64,4 @@ def test_softmax(axis): fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) - if axis is None: - with pytest.raises(TypeError): - compare_pytorch_and_py(fgraph, [test_input]) - else: - compare_pytorch_and_py(fgraph, [test_input]) + compare_pytorch_and_py(fgraph, [test_input]) From f42e2a0293676868621902b6c18d9e3176e734b1 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Mon, 24 Jun 2024 23:44:28 +0100 Subject: [PATCH 42/45] Implemented log softmax --- pytensor/link/pytorch/dispatch/elemwise.py | 15 ++++++++++++++- tests/link/pytorch/test_elemwise.py | 12 +++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 39592a8378..f61ff4cea4 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,7 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.special import Softmax +from pytensor.tensor.special import LogSoftmax, Softmax @pytorch_funcify.register(Elemwise) @@ -48,3 +48,16 @@ def softmax(x): return torch.softmax(x.ravel(), dim=0).reshape(x.shape) return softmax + + +@pytorch_funcify.register(LogSoftmax) +def pytorch_funcify_LogSoftmax(op, **kwargs): + axis = op.axis + + def log_softmax(x): + if axis is not None: + return torch.log_softmax(x, dim=axis) + else: + return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape) + + return log_softmax diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 947f2bb4eb..e153674b8b 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -5,7 +5,7 @@ from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as pt_elemwise -from pytensor.tensor.special import softmax +from pytensor.tensor.special import log_softmax, softmax from pytensor.tensor.type import matrix, tensor, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -65,3 +65,13 @@ def test_softmax(axis): test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) compare_pytorch_and_py(fgraph, [test_input]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_logsoftmax(axis): + x = matrix("x") + out = log_softmax(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) + + compare_pytorch_and_py(fgraph, [test_input]) From 35b17e07872d1011e407342a638673bf3d01843d Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Wed, 26 Jun 2024 00:21:56 +0100 Subject: [PATCH 43/45] Implemented softmaxgrad --- pytensor/link/pytorch/dispatch/elemwise.py | 19 ++++++++++++++++++- tests/link/pytorch/test_elemwise.py | 13 ++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index f61ff4cea4..5a76e26af9 100644 --- 
a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,7 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.special import LogSoftmax, Softmax +from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @pytorch_funcify.register(Elemwise) @@ -42,6 +42,9 @@ def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis def softmax(x): + if not torch.is_floating_point(x): + x = x.to(torch.float32) + if axis is not None: return torch.softmax(x, dim=axis) else: @@ -55,9 +58,23 @@ def pytorch_funcify_LogSoftmax(op, **kwargs): axis = op.axis def log_softmax(x): + if not torch.is_floating_point(x): + x = x.to(torch.float32) + if axis is not None: return torch.log_softmax(x, dim=axis) else: return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape) return log_softmax + + +@pytorch_funcify.register(SoftmaxGrad) +def jax_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm + + return softmax_grad diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index e153674b8b..0c4a29a676 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -5,7 +5,7 @@ from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as pt_elemwise -from pytensor.tensor.special import log_softmax, softmax +from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.type import matrix, tensor, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -75,3 +75,14 @@ def test_logsoftmax(axis): test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) compare_pytorch_and_py(fgraph, [test_input]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_softmax_grad(axis): + dy = matrix("dy") + dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + sm = matrix("sm") + sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = SoftmaxGrad(axis=axis)(dy, sm) + fgraph = FunctionGraph([dy, sm], [out]) + compare_pytorch_and_py(fgraph, [dy_value, sm_value]) From 5efc3c83e95f6aca07e46089b95c14965a532688 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Thu, 27 Jun 2024 01:15:14 +0100 Subject: [PATCH 44/45] Added checks and error raises for nonfloat inputs --- pytensor/link/pytorch/dispatch/elemwise.py | 18 +++++++++----- tests/link/pytorch/test_elemwise.py | 28 +++++++++++++++++----- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 5a76e26af9..fb50933654 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -40,11 +40,14 @@ def dimshuffle(x): @pytorch_funcify.register(Softmax) def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis + dtype = kwargs["node"].outputs[0].dtype - def softmax(x): - if not torch.is_floating_point(x): - x = x.to(torch.float32) + if not dtype.startswith("float"): + raise NotImplementedError( + "Pytorch Softmax is not currently implemented for non-float types." 
+ ) + def softmax(x): if axis is not None: return torch.softmax(x, dim=axis) else: @@ -56,11 +59,14 @@ def softmax(x): @pytorch_funcify.register(LogSoftmax) def pytorch_funcify_LogSoftmax(op, **kwargs): axis = op.axis + dtype = kwargs["node"].outputs[0].dtype - def log_softmax(x): - if not torch.is_floating_point(x): - x = x.to(torch.float32) + if not dtype.startswith("float"): + raise NotImplementedError( + "Pytorch LogSoftmax is not currently implemented for non-float types." + ) + def log_softmax(x): if axis is not None: return torch.log_softmax(x, dim=axis) else: diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 0c4a29a676..586789772f 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -57,24 +57,40 @@ def test_pytorch_elemwise(): compare_pytorch_and_py(fg, [[0.9, 0.9]]) +@pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("axis", [None, 0, 1]) -def test_softmax(axis): - x = matrix("x") +def test_softmax(axis, dtype): + x = matrix("x", dtype=dtype) out = softmax(x, axis=axis) fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) - compare_pytorch_and_py(fgraph, [test_input]) + if dtype == "int64": + with pytest.raises( + NotImplementedError, + match="Pytorch Softmax is not currently implemented for non-float types.", + ): + compare_pytorch_and_py(fgraph, [test_input]) + else: + compare_pytorch_and_py(fgraph, [test_input]) +@pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("axis", [None, 0, 1]) -def test_logsoftmax(axis): - x = matrix("x") +def test_logsoftmax(axis, dtype): + x = matrix("x", dtype=dtype) out = log_softmax(x, axis=axis) fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) - compare_pytorch_and_py(fgraph, [test_input]) + if dtype == "int64": + with pytest.raises( + NotImplementedError, + match="Pytorch LogSoftmax is not currently implemented for non-float types.", + ): + compare_pytorch_and_py(fgraph, [test_input]) + else: + compare_pytorch_and_py(fgraph, [test_input]) @pytest.mark.parametrize("axis", [None, 0, 1]) From 16e415ae57437289517ec0b8263e26778957d56b Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Thu, 27 Jun 2024 01:15:14 +0100 Subject: [PATCH 45/45] Added checks and error raises for nonfloat inputs --- pytensor/link/pytorch/dispatch/elemwise.py | 18 +++++++++----- tests/link/pytorch/test_elemwise.py | 28 +++++++++++++++++----- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 5a76e26af9..0ddb25765f 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -40,11 +40,14 @@ def dimshuffle(x): @pytorch_funcify.register(Softmax) def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis + dtype = kwargs["node"].inputs[0].dtype - def softmax(x): - if not torch.is_floating_point(x): - x = x.to(torch.float32) + if not dtype.startswith("float"): + raise NotImplementedError( + "Pytorch Softmax is not currently implemented for non-float types." 
+ ) + def softmax(x): if axis is not None: return torch.softmax(x, dim=axis) else: @@ -56,11 +59,14 @@ def softmax(x): @pytorch_funcify.register(LogSoftmax) def pytorch_funcify_LogSoftmax(op, **kwargs): axis = op.axis + dtype = kwargs["node"].inputs[0].dtype - def log_softmax(x): - if not torch.is_floating_point(x): - x = x.to(torch.float32) + if not dtype.startswith("float"): + raise NotImplementedError( + "Pytorch LogSoftmax is not currently implemented for non-float types." + ) + def log_softmax(x): if axis is not None: return torch.log_softmax(x, dim=axis) else: diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 0c4a29a676..586789772f 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -57,24 +57,40 @@ def test_pytorch_elemwise(): compare_pytorch_and_py(fg, [[0.9, 0.9]]) +@pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("axis", [None, 0, 1]) -def test_softmax(axis): - x = matrix("x") +def test_softmax(axis, dtype): + x = matrix("x", dtype=dtype) out = softmax(x, axis=axis) fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) - compare_pytorch_and_py(fgraph, [test_input]) + if dtype == "int64": + with pytest.raises( + NotImplementedError, + match="Pytorch Softmax is not currently implemented for non-float types.", + ): + compare_pytorch_and_py(fgraph, [test_input]) + else: + compare_pytorch_and_py(fgraph, [test_input]) +@pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("axis", [None, 0, 1]) -def test_logsoftmax(axis): - x = matrix("x") +def test_logsoftmax(axis, dtype): + x = matrix("x", dtype=dtype) out = log_softmax(x, axis=axis) fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) - compare_pytorch_and_py(fgraph, [test_input]) + if dtype == "int64": + with pytest.raises( + NotImplementedError, + match="Pytorch LogSoftmax is not currently implemented for non-float types.", + ): + compare_pytorch_and_py(fgraph, [test_input]) + else: + compare_pytorch_and_py(fgraph, [test_input]) @pytest.mark.parametrize("axis", [None, 0, 1])
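
The softmax and log-softmax dispatches added in this series can be exercised end to end through the "PYTORCH" mode used by the tests above. The following is a minimal sketch rather than part of the patches: it assumes a PyTensor checkout with these changes applied and torch installed, and (as in test_shared above) that the compiled function returns CPU torch tensors that numpy can compare against.

import numpy as np
import torch

import pytensor
from pytensor.tensor.special import log_softmax, softmax
from pytensor.tensor.type import matrix

x = matrix("x")
f = pytensor.function(
    [x], [softmax(x, axis=1), log_softmax(x, axis=1)], mode="PYTORCH"
)

x_val = np.arange(6, dtype=pytensor.config.floatX).reshape(2, 3)
sm_res, log_sm_res = f(x_val)

# The backend routes these Ops to torch.softmax / torch.log_softmax,
# so the results should match torch called directly on the same data.
np.testing.assert_allclose(
    sm_res.cpu(), torch.softmax(torch.tensor(x_val), dim=1), rtol=1e-6
)
np.testing.assert_allclose(
    log_sm_res.cpu(), torch.log_softmax(torch.tensor(x_val), dim=1), rtol=1e-6
)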
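
The SoftmaxGrad dispatch above computes dy * sm - sum(dy * sm, axis) * sm, which is the vector-Jacobian product of softmax with respect to its input, evaluated at the already-computed output sm. A quick standalone check of that identity against torch autograd (a sketch, independent of PyTensor):

import torch

torch.manual_seed(0)
x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
dy = torch.randn(2, 3, dtype=torch.float64)

sm = torch.softmax(x, dim=1)
# Gradient of softmax w.r.t. x when the gradient on its output is dy.
(vjp,) = torch.autograd.grad(sm, x, grad_outputs=dy)

# The same quantity written the way the SoftmaxGrad dispatch writes it.
dy_times_sm = dy * sm
manual = dy_times_sm - torch.sum(dy_times_sm, dim=1, keepdim=True) * sm

assert torch.allclose(vjp, manual)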