From ffad937718687d5fb4e35e45234b05c8d7c976d4 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 3 Jul 2024 15:40:48 +0800 Subject: [PATCH 01/10] Added PyTorch link and unit tests for normal dot --- pytensor/link/pytorch/dispatch/nlinalg.py | 12 +++++++++ tests/link/pytorch/test_nlinalg.py | 30 +++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/nlinalg.py create mode 100644 tests/link/pytorch/test_nlinalg.py diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py new file mode 100644 index 0000000000..4275424f0a --- /dev/null +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -0,0 +1,12 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.math import Dot + + +@pytorch_funcify.register(Dot) +def pytorch_funcify_Dot(op, **kwargs): + def dot(x, y): + return torch.matmul(x, y) + + return dot diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py new file mode 100644 index 0000000000..9af527a249 --- /dev/null +++ b/tests/link/pytorch/test_nlinalg.py @@ -0,0 +1,30 @@ +import numpy as np + +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.tensor.type import matrix, scalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_tensor_basics(): + y = vector("y") + y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + x = vector("x") + x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + A = matrix("A") + A.tag.test_value = np.array([[6, 3], [3, 0]], dtype=config.floatX) + alpha = scalar("alpha") + alpha.tag.test_value = np.array(3.0, dtype=config.floatX) + beta = scalar("beta") + beta.tag.test_value = np.array(5.0, dtype=config.floatX) + + # 1D * 2D * 1D + out = y.dot(alpha * A).dot(x) + beta * y + fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # 2D * 2D + out = A.dot(A * alpha) + beta * A + fgraph = FunctionGraph([A, alpha, beta], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) From 5121a85715ea385c573764e1d1289de672995ca7 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sat, 6 Jul 2024 17:30:39 +0800 Subject: [PATCH 02/10] Changed implementation of dot. Renamed tests --- pytensor/link/pytorch/dispatch/nlinalg.py | 7 ++++++- tests/link/pytorch/test_nlinalg.py | 19 ++++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py index 4275424f0a..f6243b26f3 100644 --- a/pytensor/link/pytorch/dispatch/nlinalg.py +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -7,6 +7,11 @@ @pytorch_funcify.register(Dot) def pytorch_funcify_Dot(op, **kwargs): def dot(x, y): - return torch.matmul(x, y) + # Case 1: Vector Product/Matrix Multiplication/1-D Broadcastable Vector + if x.shape < 3 and y.shape < 3: + return torch.matmul(x, y) + else: + # Case 2: Stackable batch dimension + return torch.tensordot(x, y, dims=([-1], [-2])) return dot diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 9af527a249..35802434a0 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -3,11 +3,15 @@ from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value -from pytensor.tensor.type import matrix, scalar, vector +from pytensor.tensor.type import matrix, scalar, tensor3, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py -def test_tensor_basics(): +def test_pytorch_dot(): + a = tensor3("a") + a.tag.test_value = np.zeros((3, 2, 4)).astype(config.floatX) + b = tensor3("b") + b.tag.test_value = np.zeros((3, 4, 1)).astype(config.floatX) y = vector("y") y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x = vector("x") @@ -19,12 +23,17 @@ def test_tensor_basics(): beta = scalar("beta") beta.tag.test_value = np.array(5.0, dtype=config.floatX) - # 1D * 2D * 1D - out = y.dot(alpha * A).dot(x) + beta * y - fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + # 3D * 3D + out = a.dot(b * alpha) + beta * b + fgraph = FunctionGraph([a, b, alpha, beta], [out]) compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) # 2D * 2D out = A.dot(A * alpha) + beta * A fgraph = FunctionGraph([A, alpha, beta], [out]) compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # 1D * 2D and 1D * 1D + out = y.dot(alpha * A).dot(x) + beta * y + fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) From 2721c5a21218b9dc0cbdddb400590de2f359b542 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sat, 6 Jul 2024 17:33:03 +0800 Subject: [PATCH 03/10] Changed dot implementation --- pytensor/link/pytorch/dispatch/nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py index f6243b26f3..267114d267 100644 --- a/pytensor/link/pytorch/dispatch/nlinalg.py +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -8,7 +8,7 @@ def pytorch_funcify_Dot(op, **kwargs): def dot(x, y): # Case 1: Vector Product/Matrix Multiplication/1-D Broadcastable Vector - if x.shape < 3 and y.shape < 3: + if x.shape == 1 or y.shape == 1 or (x.shape < 3 and y.shape < 3): return torch.matmul(x, y) else: # Case 2: Stackable batch dimension From 03bb3a80b54f3f9c6d832bb2806cec1fbab13648 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Mon, 8 Jul 2024 17:57:26 +0800 Subject: [PATCH 04/10] Reverted logic to correct scope for math.dot --- pytensor/link/pytorch/dispatch/nlinalg.py | 7 +------ tests/link/pytorch/test_nlinalg.py | 11 +---------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py index 267114d267..4275424f0a 100644 --- a/pytensor/link/pytorch/dispatch/nlinalg.py +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -7,11 +7,6 @@ @pytorch_funcify.register(Dot) def pytorch_funcify_Dot(op, **kwargs): def dot(x, y): - # Case 1: Vector Product/Matrix Multiplication/1-D Broadcastable Vector - if x.shape == 1 or y.shape == 1 or (x.shape < 3 and y.shape < 3): - return torch.matmul(x, y) - else: - # Case 2: Stackable batch dimension - return torch.tensordot(x, y, dims=([-1], [-2])) + return torch.matmul(x, y) return dot diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 35802434a0..ece5471e65 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -3,15 +3,11 @@ from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value -from pytensor.tensor.type import matrix, scalar, tensor3, vector +from pytensor.tensor.type import matrix, scalar, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py def test_pytorch_dot(): - a = tensor3("a") - a.tag.test_value = np.zeros((3, 2, 4)).astype(config.floatX) - b = tensor3("b") - b.tag.test_value = np.zeros((3, 4, 1)).astype(config.floatX) y = vector("y") y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x = vector("x") @@ -23,11 +19,6 @@ def test_pytorch_dot(): beta = scalar("beta") beta.tag.test_value = np.array(5.0, dtype=config.floatX) - # 3D * 3D - out = a.dot(b * alpha) + beta * b - fgraph = FunctionGraph([a, b, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - # 2D * 2D out = A.dot(A * alpha) + beta * A fgraph = FunctionGraph([A, alpha, beta], [out]) From 2cf0ed2580cb2b1264ba562f3cd58494ef352925 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Mon, 15 Jul 2024 10:27:35 +0800 Subject: [PATCH 05/10] Reverted folder structure and added BatchedDot --- pytensor/link/pytorch/dispatch/__init__.py | 3 ++ pytensor/link/pytorch/dispatch/blas.py | 14 ++++++++ .../pytorch/dispatch/{nlinalg.py => math.py} | 0 tests/link/pytorch/test_blas.py | 36 +++++++++++++++++++ .../pytorch/{test_nlinalg.py => test_math.py} | 0 5 files changed, 53 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/blas.py rename pytensor/link/pytorch/dispatch/{nlinalg.py => math.py} (100%) create mode 100644 tests/link/pytorch/test_blas.py rename tests/link/pytorch/{test_nlinalg.py => test_math.py} (100%) diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index b6af171995..73ffc6aeae 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -2,6 +2,9 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify # # Load dispatch specializations +import pytensor.link.pytorch.dispatch.blas import pytensor.link.pytorch.dispatch.scalar import pytensor.link.pytorch.dispatch.elemwise +import pytensor.link.pytorch.dispatch.math + # isort: on diff --git a/pytensor/link/pytorch/dispatch/blas.py b/pytensor/link/pytorch/dispatch/blas.py new file mode 100644 index 0000000000..5691551998 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/blas.py @@ -0,0 +1,14 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.blas import BatchedDot + + +@pytorch_funcify.register(BatchedDot) +def pytorch_funcify_BatchedDot(op, **kwargs): + def batched_dot(a, b): + if a.shape[0] != b.shape[0]: + raise TypeError("Shapes must match in the 0-th dimension") + return torch.bmm(a, b) + + return batched_dot diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/math.py similarity index 100% rename from pytensor/link/pytorch/dispatch/nlinalg.py rename to pytensor/link/pytorch/dispatch/math.py diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py new file mode 100644 index 0000000000..5bdabe2e60 --- /dev/null +++ b/tests/link/pytorch/test_blas.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest + +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.link.pytorch import PytorchLinker +from pytensor.tensor import blas as pt_blas +from pytensor.tensor.type import tensor3 +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_BatchedDot(): + # tensor3 . tensor3 + a = tensor3("a") + a.tag.test_value = ( + np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + ) + b = tensor3("b") + b.tag.test_value = ( + np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + ) + out = pt_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # A dimension mismatch should raise a TypeError for compatibility + inputs = [get_test_value(a)[:-1], get_test_value(b)] + opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + pytorch_mode = Mode(PytorchLinker(), opts) + pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode) + with pytest.raises(TypeError): + pytensor_jax_fn(*inputs) diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_math.py similarity index 100% rename from tests/link/pytorch/test_nlinalg.py rename to tests/link/pytorch/test_math.py From 307a3fb6a14174f73ad09119ac33f347aca14574 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Mon, 15 Jul 2024 10:30:35 +0800 Subject: [PATCH 06/10] Fixed minor typo in test naming --- tests/link/pytorch/test_blas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py index 5bdabe2e60..63d32ff0f8 100644 --- a/tests/link/pytorch/test_blas.py +++ b/tests/link/pytorch/test_blas.py @@ -31,6 +31,6 @@ def test_pytorch_BatchedDot(): inputs = [get_test_value(a)[:-1], get_test_value(b)] opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) pytorch_mode = Mode(PytorchLinker(), opts) - pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode) + pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode) with pytest.raises(TypeError): - pytensor_jax_fn(*inputs) + pytensor_pytorch_fn(*inputs) From 143a75a69f764f492075e7b069428a775c22cd53 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Mon, 15 Jul 2024 10:36:05 +0800 Subject: [PATCH 07/10] Fixed __init__.py file for tests to run --- pytensor/link/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/link/__init__.py b/pytensor/link/__init__.py index e69de29bb2..c8c236a854 100644 --- a/pytensor/link/__init__.py +++ b/pytensor/link/__init__.py @@ -0,0 +1 @@ +from pytensor.link.pytorch.linker import PytorchLinker From 2d74b31298aeab765907d06e97a5486f8303d034 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 17 Jul 2024 10:27:40 +0800 Subject: [PATCH 08/10] Rewrite test to reuse pytorch function --- tests/link/pytorch/test_blas.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py index 63d32ff0f8..ce64d53f08 100644 --- a/tests/link/pytorch/test_blas.py +++ b/tests/link/pytorch/test_blas.py @@ -1,13 +1,11 @@ import numpy as np import pytest -from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value -from pytensor.graph.rewriting.db import RewriteDatabaseQuery -from pytensor.link.pytorch import PytorchLinker +from pytensor.link.pytorch.linker import PytorchLinker from pytensor.tensor import blas as pt_blas from pytensor.tensor.type import tensor3 from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -25,12 +23,13 @@ def test_pytorch_BatchedDot(): ) out = pt_blas.BatchedDot()(a, b) fgraph = FunctionGraph([a, b], [out]) - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + pytensor_pytorch_fn, _ = compare_pytorch_and_py( + fgraph, [get_test_value(i) for i in fgraph.inputs] + ) # A dimension mismatch should raise a TypeError for compatibility inputs = [get_test_value(a)[:-1], get_test_value(b)] - opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) - pytorch_mode = Mode(PytorchLinker(), opts) - pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode) + pytorch_mode_no_rewrites = Mode(PytorchLinker(), None) + pytensor_pytorch_fn.mode = pytorch_mode_no_rewrites with pytest.raises(TypeError): pytensor_pytorch_fn(*inputs) From 4deea70cc21a174a5eb93122fdb112c6867f1963 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 17 Jul 2024 16:01:06 +0800 Subject: [PATCH 09/10] Removed get_test_value --- tests/link/pytorch/test_blas.py | 19 ++++--------------- tests/link/pytorch/test_math.py | 15 +++++++-------- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py index ce64d53f08..82aba58b43 100644 --- a/tests/link/pytorch/test_blas.py +++ b/tests/link/pytorch/test_blas.py @@ -1,11 +1,8 @@ import numpy as np import pytest -from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value -from pytensor.link.pytorch.linker import PytorchLinker from pytensor.tensor import blas as pt_blas from pytensor.tensor.type import tensor3 from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -14,22 +11,14 @@ def test_pytorch_BatchedDot(): # tensor3 . tensor3 a = tensor3("a") - a.tag.test_value = ( - np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) - ) + A = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) b = tensor3("b") - b.tag.test_value = ( - np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) - ) + B = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) out = pt_blas.BatchedDot()(a, b) fgraph = FunctionGraph([a, b], [out]) - pytensor_pytorch_fn, _ = compare_pytorch_and_py( - fgraph, [get_test_value(i) for i in fgraph.inputs] - ) + pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [A, B]) # A dimension mismatch should raise a TypeError for compatibility - inputs = [get_test_value(a)[:-1], get_test_value(b)] - pytorch_mode_no_rewrites = Mode(PytorchLinker(), None) - pytensor_pytorch_fn.mode = pytorch_mode_no_rewrites + inputs = [A[:-1], B] with pytest.raises(TypeError): pytensor_pytorch_fn(*inputs) diff --git a/tests/link/pytorch/test_math.py b/tests/link/pytorch/test_math.py index ece5471e65..c756d6a4ff 100644 --- a/tests/link/pytorch/test_math.py +++ b/tests/link/pytorch/test_math.py @@ -2,29 +2,28 @@ from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.tensor.type import matrix, scalar, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py def test_pytorch_dot(): y = vector("y") - y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + y_value = np.r_[1.0, 2.0].astype(config.floatX) x = vector("x") - x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + x_value = np.r_[3.0, 4.0].astype(config.floatX) A = matrix("A") - A.tag.test_value = np.array([[6, 3], [3, 0]], dtype=config.floatX) + A_value = np.array([[6, 3], [3, 0]], dtype=config.floatX) alpha = scalar("alpha") - alpha.tag.test_value = np.array(3.0, dtype=config.floatX) + alpha_value = np.array(3.0, dtype=config.floatX) beta = scalar("beta") - beta.tag.test_value = np.array(5.0, dtype=config.floatX) + beta_value = np.array(5.0, dtype=config.floatX) # 2D * 2D out = A.dot(A * alpha) + beta * A fgraph = FunctionGraph([A, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_pytorch_and_py(fgraph, [A_value, alpha_value, beta_value]) # 1D * 2D and 1D * 1D out = y.dot(alpha * A).dot(x) + beta * y fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_pytorch_and_py(fgraph, [y_value, x_value, A_value, alpha_value, beta_value]) From cab9db8bf54233515e079e878f806595a7f52d2f Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Thu, 18 Jul 2024 09:26:58 +0800 Subject: [PATCH 10/10] Changed variable names --- tests/link/pytorch/test_blas.py | 8 ++++---- tests/link/pytorch/test_math.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py index 82aba58b43..35f7dd7b6a 100644 --- a/tests/link/pytorch/test_blas.py +++ b/tests/link/pytorch/test_blas.py @@ -11,14 +11,14 @@ def test_pytorch_BatchedDot(): # tensor3 . tensor3 a = tensor3("a") - A = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) b = tensor3("b") - B = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) out = pt_blas.BatchedDot()(a, b) fgraph = FunctionGraph([a, b], [out]) - pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [A, B]) + pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test]) # A dimension mismatch should raise a TypeError for compatibility - inputs = [A[:-1], B] + inputs = [a_test[:-1], b_test] with pytest.raises(TypeError): pytensor_pytorch_fn(*inputs) diff --git a/tests/link/pytorch/test_math.py b/tests/link/pytorch/test_math.py index c756d6a4ff..affca4ad32 100644 --- a/tests/link/pytorch/test_math.py +++ b/tests/link/pytorch/test_math.py @@ -8,22 +8,22 @@ def test_pytorch_dot(): y = vector("y") - y_value = np.r_[1.0, 2.0].astype(config.floatX) + y_test = np.r_[1.0, 2.0].astype(config.floatX) x = vector("x") - x_value = np.r_[3.0, 4.0].astype(config.floatX) + x_test = np.r_[3.0, 4.0].astype(config.floatX) A = matrix("A") - A_value = np.array([[6, 3], [3, 0]], dtype=config.floatX) + A_test = np.array([[6, 3], [3, 0]], dtype=config.floatX) alpha = scalar("alpha") - alpha_value = np.array(3.0, dtype=config.floatX) + alpha_test = np.array(3.0, dtype=config.floatX) beta = scalar("beta") - beta_value = np.array(5.0, dtype=config.floatX) + beta_test = np.array(5.0, dtype=config.floatX) # 2D * 2D out = A.dot(A * alpha) + beta * A fgraph = FunctionGraph([A, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [A_value, alpha_value, beta_value]) + compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test]) # 1D * 2D and 1D * 1D out = y.dot(alpha * A).dot(x) + beta * y fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [y_value, x_value, A_value, alpha_value, beta_value]) + compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test])