diff --git a/environment.yml b/environment.yml
index 54d6913fba..200be42c3f 100644
--- a/environment.yml
+++ b/environment.yml
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0
+  - numpy>=1.17.0,<2
   - scipy>=0.14,<1.14.0
   - filelock
   - etuples
diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index f39e108bed..0ddb25765f 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 
 
 @pytorch_funcify.register(Elemwise)
@@ -34,3 +35,52 @@ def dimshuffle(x):
         return res
 
     return dimshuffle
+
+
+@pytorch_funcify.register(Softmax)
+def pytorch_funcify_Softmax(op, **kwargs):
+    axis = op.axis
+    dtype = kwargs["node"].inputs[0].dtype
+
+    if not dtype.startswith("float"):
+        raise NotImplementedError(
+            "Pytorch Softmax is not currently implemented for non-float types."
+        )
+
+    def softmax(x):
+        if axis is not None:
+            return torch.softmax(x, dim=axis)
+        else:
+            return torch.softmax(x.ravel(), dim=0).reshape(x.shape)
+
+    return softmax
+
+
+@pytorch_funcify.register(LogSoftmax)
+def pytorch_funcify_LogSoftmax(op, **kwargs):
+    axis = op.axis
+    dtype = kwargs["node"].inputs[0].dtype
+
+    if not dtype.startswith("float"):
+        raise NotImplementedError(
+            "Pytorch LogSoftmax is not currently implemented for non-float types."
+        )
+
+    def log_softmax(x):
+        if axis is not None:
+            return torch.log_softmax(x, dim=axis)
+        else:
+            return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)
+
+    return log_softmax
+
+
+@pytorch_funcify.register(SoftmaxGrad)
+def pytorch_funcify_SoftmaxGrad(op, **kwargs):
+    axis = op.axis
+
+    def softmax_grad(dy, sm):
+        dy_times_sm = dy * sm
+        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
+
+    return softmax_grad
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index 1d843b8051..586789772f 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -1,9 +1,11 @@
 import numpy as np
+import pytest
 
 import pytensor.tensor as pt
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import elemwise as pt_elemwise
+from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
@@ -53,3 +55,50 @@ def test_pytorch_elemwise():
 
     fg = FunctionGraph([x], [out])
     compare_pytorch_and_py(fg, [[0.9, 0.9]])
+
+
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis, dtype):
+    x = matrix("x", dtype=dtype)
+    out = softmax(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
+
+    if dtype == "int64":
+        with pytest.raises(
+            NotImplementedError,
+            match="Pytorch Softmax is not currently implemented for non-float types.",
+        ):
+            compare_pytorch_and_py(fgraph, [test_input])
+    else:
+        compare_pytorch_and_py(fgraph, [test_input])
+
+
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_logsoftmax(axis, dtype):
+    x = matrix("x", dtype=dtype)
+    out = log_softmax(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
+
+    if dtype == "int64":
+        with pytest.raises(
+            NotImplementedError,
+            match="Pytorch LogSoftmax is not currently implemented for non-float types.",
+        ):
+            compare_pytorch_and_py(fgraph, [test_input])
+    else:
+        compare_pytorch_and_py(fgraph, [test_input])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax_grad(axis):
+    dy = matrix("dy")
+    dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    sm = matrix("sm")
+    sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = SoftmaxGrad(axis=axis)(dy, sm)
+    fgraph = FunctionGraph([dy, sm], [out])
+    compare_pytorch_and_py(fgraph, [dy_value, sm_value])
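
A quick sanity check on the SoftmaxGrad dispatch above: the expression dy * sm - sum(dy * sm) * sm is the closed-form vector-Jacobian product of softmax, so it can be compared directly against torch.autograd. The snippet below is an illustrative sketch and not part of the patch; the helper name and test values are chosen here for demonstration only.

# Illustrative sketch (not part of the patch): verify the SoftmaxGrad
# expression against PyTorch's own autograd VJP for softmax.
import torch


def softmax_grad(dy, sm, axis):
    # Same closed-form expression used in the dispatch above.
    dy_times_sm = dy * sm
    return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm


x = torch.arange(6, dtype=torch.float64).reshape(2, 3).requires_grad_()
dy = torch.tensor([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]], dtype=torch.float64)
sm = torch.softmax(x, dim=1)

# VJP computed by autograd, compared with the manual expression.
expected = torch.autograd.grad(sm, x, grad_outputs=dy)[0]
assert torch.allclose(softmax_grad(dy, sm.detach(), axis=1), expected)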