Add initial support for PyTorch backend #764

Merged
41 commits merged on Jun 20, 2024
Commits
27e2526
Add pytorch support for some basic Ops
HarshvirSandhu May 13, 2024
629d00b
update variable names, docstrings
HarshvirSandhu May 13, 2024
3eceb56
Avoid numpy conversion of torch Tensors
HarshvirSandhu May 17, 2024
3cde964
Fix typify and CheckAndRaise
HarshvirSandhu May 17, 2024
c003aa5
Fix Elemwise Ops
HarshvirSandhu May 17, 2024
8dc406e
Fix Scalar Ops
HarshvirSandhu May 17, 2024
a8f6ddb
Fix ruff-format
HarshvirSandhu May 17, 2024
9d535f5
Initial setup for pytorch tests
HarshvirSandhu May 23, 2024
c5600da
Fix mode parameters for pytorch
HarshvirSandhu May 23, 2024
54b6248
Prevent conversion of scalars to numpy
HarshvirSandhu May 23, 2024
19454b3
Update TensorConstantSignature and map dtypes to Tensor types
HarshvirSandhu May 23, 2024
92d7114
Add tests for basic ops
HarshvirSandhu May 23, 2024
5aae0e5
Remove torch from user facing API
HarshvirSandhu May 29, 2024
8c174dd
Add function to convert numpy arrays to pytorch tensors
HarshvirSandhu May 29, 2024
0977c3a
Avoid copy when converting to tensor
HarshvirSandhu May 29, 2024
1c23825
Fix tests
HarshvirSandhu May 29, 2024
c9195a8
Remove dispatches that are not tested
HarshvirSandhu May 31, 2024
b07805c
set path for pytorch tests
HarshvirSandhu May 31, 2024
9e8d3fc
Remove tensorflow probability from yml
HarshvirSandhu Jun 4, 2024
a2d3afa
Add checks for runtime broadcasting
HarshvirSandhu Jun 4, 2024
a577a80
Remove IfElse
HarshvirSandhu Jun 4, 2024
499a174
Remove dev notebook
HarshvirSandhu Jun 12, 2024
2826613
Fix check and raise
HarshvirSandhu Jun 12, 2024
62ffcec
Fix compare_pytorch_and_py
HarshvirSandhu Jun 12, 2024
acdbba1
Fix DimShuffle
HarshvirSandhu Jun 12, 2024
2519c65
Add tests for Elemwise operations
HarshvirSandhu Jun 12, 2024
eb6d5c2
Fix test for CheckAndRaise
HarshvirSandhu Jun 14, 2024
9f02a4f
Remove duplicate function
HarshvirSandhu Jun 14, 2024
caf2965
Remove device from pytorch_typify
HarshvirSandhu Jun 15, 2024
bf87eb9
Merge branch 'main' of https://github.com/HarshvirSandhu/pytensor int…
HarshvirSandhu Jun 15, 2024
2c27683
Solve merge conflict
HarshvirSandhu Jun 15, 2024
c603c6b
Use micromamba for pytorch install
HarshvirSandhu Jun 15, 2024
3f17107
Fix pytorch linker
HarshvirSandhu Jun 16, 2024
e850d8d
Fix typify and deepcopy
HarshvirSandhu Jun 16, 2024
e682fc4
Parametrize device in all tests
HarshvirSandhu Jun 16, 2024
bf4cf92
Install torch with cuda
HarshvirSandhu Jun 16, 2024
899e7f9
Fix test_pytorch_FunctionGraph_once
HarshvirSandhu Jun 16, 2024
04d2935
Remove device argument from test
HarshvirSandhu Jun 16, 2024
8ec7661
remove device from elemwise tests and add assertions
HarshvirSandhu Jun 17, 2024
bb7df41
skip tests if cuda is not available
HarshvirSandhu Jun 17, 2024
0441cf2
Fix tests
HarshvirSandhu Jun 18, 2024
13 changes: 12 additions & 1 deletion .github/workflows/test.yml
@@ -76,6 +76,7 @@ jobs:
float32: [0, 1]
install-numba: [0]
install-jax: [0]
install-torch: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -116,6 +117,11 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-torch: 1
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
steps:
- uses: actions/checkout@v4
with:
@@ -142,9 +148,12 @@ jobs:
- name: Install dependencies
shell: micromamba-shell {0}
run: |

micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi

pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -153,6 +162,7 @@ jobs:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch }}

- name: Run tests
shell: micromamba-shell {0}
@@ -199,7 +209,7 @@ jobs:
- name: Install dependencies
shell: micromamba-shell {0}
run: |
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -268,3 +278,4 @@ jobs:
directory: ./coverage/
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
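
To mirror the new `install-torch` matrix entry locally, the PyTorch test path can be invoked directly; a minimal sketch, assuming torch and pytest are already installed in the active environment:

    import pytest

    # Run only the tests exercised by the new CI part "tests/link/pytorch"
    pytest.main(["tests/link/pytorch", "-v"])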

15 changes: 15 additions & 0 deletions pytensor/compile/mode.py
@@ -28,6 +28,7 @@
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker


@@ -47,6 +48,7 @@
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
}

@@ -460,6 +462,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
],
),
)
PYTORCH = Mode(
PytorchLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
],
),
)
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
@@ -474,6 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
}

instantiated_default_mode = None
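
With the `PYTORCH` mode registered above, the backend can be selected by name when compiling a function. A minimal usage sketch (assuming this branch is installed alongside torch; the exact output container depends on the linker's output filter):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.vector("y")

    # Compile with the PyTorch linker/mode added in this PR
    fn = pytensor.function([x, y], x + y, mode="PYTORCH")

    print(fn([1.0, 2.0], [3.0, 4.0]))  # expected values: [4., 6.]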
6 changes: 5 additions & 1 deletion pytensor/link/basic.py
@@ -600,6 +600,10 @@ def create_thunk_inputs(self, storage_map: dict[Variable, list[Any]]) -> list[Any]:
def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``."""

def input_filter(self, inp: Any) -> Any:
"""Apply a filter to the data input."""
return inp

def output_filter(self, var: Variable, out: Any) -> Any:
"""Apply a filter to the data output by a JITed function call."""
return out
@@ -657,7 +661,7 @@ def thunk(
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_var][0] = True
7 changes: 7 additions & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -0,0 +1,7 @@
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify

# Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
# isort: on
60 changes: 60 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -0,0 +1,60 @@
from functools import singledispatch

import torch

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
return torch.as_tensor(data, dtype=dtype)


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
    raise NotImplementedError(
        f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
    )


@pytorch_funcify.register(FunctionGraph)
def pytorch_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="pytorch_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
pytorch_funcify,
type_conversion_fn=pytorch_typify,
fgraph_name=fgraph_name,
**kwargs,
)


@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
error = op.exc_type
msg = op.msg

def assert_fn(x, *conditions):
for cond in conditions:
if not cond.item():
raise error(msg)
return x

return assert_fn


@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return x.clone()

return deepcopyop
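
A small sanity check of the default `pytorch_typify` path above; a sketch assuming the module layout from this diff:

    import numpy as np
    import torch

    from pytensor.link.pytorch.dispatch.basic import pytorch_typify

    # The default singledispatch branch wraps data with torch.as_tensor,
    # which avoids a copy for NumPy arrays where possible.
    arr = np.arange(4.0)
    t = pytorch_typify(arr)
    assert isinstance(t, torch.Tensor)
    assert t.dtype == torch.float64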
36 changes: 36 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -0,0 +1,36 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise


@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)

return elemwise_fn


@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
    def dimshuffle(x):
        res = torch.permute(x, op.transposition)

        shape = list(res.shape[: len(op.shuffle)])

        for augm in op.augment:
            shape.insert(augm, 1)

        res = torch.reshape(res, shape)

        if not op.inplace:
            res = res.clone()

        return res

    return dimshuffle
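
The permute/reshape pair above mirrors how a `DimShuffle` such as `(1, 0, "x")` evaluates; a standalone torch sketch of the same steps:

    import torch

    # DimShuffle (1, 0, "x"): transpose a 2D tensor, then append a broadcastable axis.
    x = torch.arange(6.0).reshape(2, 3)
    res = torch.permute(x, (1, 0))   # transposition part
    shape = list(res.shape)
    shape.insert(2, 1)               # augment part: insert a length-1 axis
    res = torch.reshape(res, shape)
    assert res.shape == (3, 2, 1)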
40 changes: 40 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -0,0 +1,40 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
ScalarOp,
)


@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
    """Return a PyTorch function that implements the same computation as the scalar `Op`.

    This dispatch is expected to return a PyTorch function that works on array inputs as `Elemwise` does,
    even though it's dispatched on the scalar `Op`.
    """

    nfunc_spec = getattr(op, "nfunc_spec", None)
    if nfunc_spec is None:
        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")

    func_name = nfunc_spec[0]

    pytorch_func = getattr(torch, func_name)

    if len(node.inputs) > op.nfunc_spec[1]:
        # Some Scalar Ops accept a variable number of inputs, behaving as a variadic function,
        # even though the base Op from `func_name` is specified as a binary Op.
        # This happens with `Add`, which can work as a `Sum` for multiple scalars.
        pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None)
        if not pytorch_variadic_func:
            raise NotImplementedError(
                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
            )

        def pytorch_func(*args):
            return pytorch_variadic_func(
                torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
            )

    return pytorch_func
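
For the variadic branch, `Add` maps to a reduction; a small check of what that branch computes, assuming `nfunc_variadic` resolves to `torch.sum` for `Add`:

    import torch

    # Three scalar inputs to Add: stack after broadcasting, then reduce along axis 0.
    args = (torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0))
    stacked = torch.stack(torch.broadcast_tensors(*args), axis=0)
    print(torch.sum(stacked, axis=0))  # tensor(6.)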
36 changes: 36 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -0,0 +1,36 @@
from typing import Any

from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker


class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""

def input_filter(self, inp: Any) -> Any:
from pytensor.link.pytorch.dispatch import pytorch_typify

return pytorch_typify(inp)

def output_filter(self, var: Variable, out: Any) -> Any:
return out.cpu()

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

return pytorch_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
)

def jit_compile(self, fn):
import torch

return torch.compile(fn)

def create_thunk_inputs(self, storage_map):
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
thunk_inputs.append(sinput)

return thunk_inputs
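
The `jit_compile` step is a thin wrapper around `torch.compile`; a standalone sketch of what the linker does with the funcified graph:

    import torch

    def fn(a, b):
        # Stand-in for the function produced by pytorch_funcify on a FunctionGraph
        return a * b + a

    compiled = torch.compile(fn)  # what PytorchLinker.jit_compile returns
    out = compiled(torch.ones(3), torch.arange(3.0))
    print(out)  # tensor([1., 2., 3.])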
Empty file added tests/link/pytorch/__init__.py