Skip to content

Commit 6ad1c5c

Browse files
authored
Implement Dot and BatchedDot in PyTensor (#878)
1 parent 426931b commit 6ad1c5c

File tree

6 files changed

+84
-1
lines changed

6 files changed

+84
-1
lines changed

pytensor/link/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pytensor.link.pytorch.linker import PytorchLinker

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
33

44
# # Load dispatch specializations
5+
import pytensor.link.pytorch.dispatch.blas
56
import pytensor.link.pytorch.dispatch.scalar
67
import pytensor.link.pytorch.dispatch.elemwise
8+
import pytensor.link.pytorch.dispatch.math
79
import pytensor.link.pytorch.dispatch.extra_ops
8-
import pytensor.link.pytorch.dispatch.sort
910
import pytensor.link.pytorch.dispatch.shape
11+
import pytensor.link.pytorch.dispatch.sort
12+
1013
# isort: on
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch import pytorch_funcify
4+
from pytensor.tensor.blas import BatchedDot
5+
6+
7+
@pytorch_funcify.register(BatchedDot)
8+
def pytorch_funcify_BatchedDot(op, **kwargs):
9+
def batched_dot(a, b):
10+
if a.shape[0] != b.shape[0]:
11+
raise TypeError("Shapes must match in the 0-th dimension")
12+
return torch.bmm(a, b)
13+
14+
return batched_dot
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch import pytorch_funcify
4+
from pytensor.tensor.math import Dot
5+
6+
7+
@pytorch_funcify.register(Dot)
8+
def pytorch_funcify_Dot(op, **kwargs):
9+
def dot(x, y):
10+
return torch.matmul(x, y)
11+
12+
return dot

tests/link/pytorch/test_blas.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.configdefaults import config
5+
from pytensor.graph.fg import FunctionGraph
6+
from pytensor.tensor import blas as pt_blas
7+
from pytensor.tensor.type import tensor3
8+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
9+
10+
11+
def test_pytorch_BatchedDot():
12+
# tensor3 . tensor3
13+
a = tensor3("a")
14+
a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
15+
b = tensor3("b")
16+
b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
17+
out = pt_blas.BatchedDot()(a, b)
18+
fgraph = FunctionGraph([a, b], [out])
19+
pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test])
20+
21+
# A dimension mismatch should raise a TypeError for compatibility
22+
inputs = [a_test[:-1], b_test]
23+
with pytest.raises(TypeError):
24+
pytensor_pytorch_fn(*inputs)

tests/link/pytorch/test_math.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
3+
from pytensor.configdefaults import config
4+
from pytensor.graph.fg import FunctionGraph
5+
from pytensor.tensor.type import matrix, scalar, vector
6+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
7+
8+
9+
def test_pytorch_dot():
10+
y = vector("y")
11+
y_test = np.r_[1.0, 2.0].astype(config.floatX)
12+
x = vector("x")
13+
x_test = np.r_[3.0, 4.0].astype(config.floatX)
14+
A = matrix("A")
15+
A_test = np.array([[6, 3], [3, 0]], dtype=config.floatX)
16+
alpha = scalar("alpha")
17+
alpha_test = np.array(3.0, dtype=config.floatX)
18+
beta = scalar("beta")
19+
beta_test = np.array(5.0, dtype=config.floatX)
20+
21+
# 2D * 2D
22+
out = A.dot(A * alpha) + beta * A
23+
fgraph = FunctionGraph([A, alpha, beta], [out])
24+
compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test])
25+
26+
# 1D * 2D and 1D * 1D
27+
out = y.dot(alpha * A).dot(x) + beta * y
28+
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
29+
compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test])

0 commit comments

Comments
 (0)