Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,16 +1259,13 @@ def test_ngrams():
def test_compare_activations_to_torch(ops, dtype, x, torch_func):
import torch

def cast_torch(scalar: float):
return torch.tensor([scalar], requires_grad=True)

func_name, pytorch_func = torch_func
forward = getattr(ops, func_name)
backward = getattr(ops, "backprop_" + func_name)
# The tolerance of isclose is set to 1e-06 instead of
# the default 1e-08 due to the GELU
x_thinc = ops.asarray([x], dtype=dtype)
x_torch = cast_torch(x)
x_torch = xp2torch(x_thinc, requires_grad=True)
y = pytorch_func(x_torch)
y_thinc = forward(x_thinc)
y.backward()
Expand Down