Skip to content

Commit ba4fcbe

Browse files
Ch0ronomatoIan SchweerricardoV94
authored
Implement OpFromGraph in PyTorch backend (#956)
Co-authored-by: Ian Schweer <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 3e55a20 commit ba4fcbe

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

pytensor/compile/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
OPT_O3,
3131
OPT_STABILIZE,
3232
OPT_UNSAFE,
33+
PYTORCH,
3334
AddDestroyHandler,
3435
AddFeatureOptimizer,
3536
Mode,

pytensor/link/pytorch/dispatch/basic.py

+16
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
import numpy as np
55
import torch
6+
import torch.compiler
67

8+
from pytensor.compile import PYTORCH
9+
from pytensor.compile.builders import OpFromGraph
710
from pytensor.compile.ops import DeepCopyOp
811
from pytensor.graph.fg import FunctionGraph
912
from pytensor.link.utils import fgraph_to_python
@@ -150,6 +153,19 @@ def makevector(*x):
150153
return makevector
151154

152155

156+
@pytorch_funcify.register(OpFromGraph)
157+
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
158+
kwargs.pop("storage_map", None)
159+
160+
# Apply inner rewrites
161+
PYTORCH.optimizer(op.fgraph)
162+
163+
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
164+
# Disable one step inlining to prevent torch from trying to import local functions
165+
# defined in `pytorch_funcify`
166+
return torch.compiler.disable(fgraph_fn, recursive=False)
167+
168+
153169
@pytorch_funcify.register(TensorFromScalar)
154170
def pytorch_funcify_TensorFromScalar(op, **kwargs):
155171
def tensorfromscalar(x):

tests/link/pytorch/test_basic.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import pytensor.tensor.basic as ptb
8+
from pytensor.compile.builders import OpFromGraph
89
from pytensor.compile.function import function
910
from pytensor.compile.mode import get_mode
1011
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -14,7 +15,7 @@
1415
from pytensor.graph.op import Op
1516
from pytensor.raise_op import CheckAndRaise
1617
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
17-
from pytensor.tensor.type import matrix, scalar, vector
18+
from pytensor.tensor.type import matrices, matrix, scalar, vector
1819

1920

2021
torch = pytest.importorskip("torch")
@@ -301,3 +302,19 @@ def test_pytorch_MakeVector():
301302
x_fg = FunctionGraph([], [x])
302303

303304
compare_pytorch_and_py(x_fg, [])
305+
306+
307+
def test_pytorch_OpFromGraph():
308+
x, y, z = matrices("xyz")
309+
ofg_1 = OpFromGraph([x, y], [x + y])
310+
ofg_2 = OpFromGraph([x, y], [x * y, x - y])
311+
312+
o1, o2 = ofg_2(y, z)
313+
out = ofg_1(x, o1) + o2
314+
315+
xv = np.ones((2, 2), dtype=config.floatX)
316+
yv = np.ones((2, 2), dtype=config.floatX) * 3
317+
zv = np.ones((2, 2), dtype=config.floatX) * 5
318+
319+
f = FunctionGraph([x, y, z], [out])
320+
compare_pytorch_and_py(f, [xv, yv, zv])

0 commit comments

Comments
 (0)