Use torch.testing.assert_close in test_functional_tensor #3876
Conversation
Mostly LGTM, but I'm wondering whether we could revisit the constraint regarding specifying both rtol and atol.
 if scripted_fn_atol >= 0:
     scripted_fn = torch.jit.script(fn)
     # scriptable function test
     s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
-    self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))
+    torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
For ref, the rtol=1e-5 comes from the current default of https://pytorch.org/docs/stable/generated/torch.allclose.html. I'd prefer to leave rtol to its default in assert_close if possible, but rtol must be set if atol is set. Would you know the reason, @pmeier? np.testing.assert_allclose doesn't seem to have this constraint.
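For illustration, a minimal sketch of the constraint under discussion (assuming torch >= 1.9, where torch.testing.assert_close is available; not part of this PR's diff):

```python
import numpy as np
import torch

t = torch.zeros(3)

# NumPy lets you override atol alone; rtol silently keeps its default (1e-7).
np.testing.assert_allclose(t.numpy(), t.numpy(), atol=1.0)

# torch.testing.assert_close requires rtol and atol to be specified
# together (or both omitted); passing only one raises.
try:
    torch.testing.assert_close(t, t, atol=1.0)
except Exception as exc:
    print(f"{type(exc).__name__}: {exc}")

# Specifying both works.
torch.testing.assert_close(t, t, rtol=0.0, atol=1.0)
```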
We went for this logic since we have non-zero defaults for rtol and atol. Imagine setting atol=0 and the tensors still match because rtol > 0. See https://github.com/pytorch/pytorch/blob/74c12da4517c789bea737dc947d6adc755f63176/torch/testing/_asserts.py#L391-L396.
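A toy example of the failure mode this guards against; the float32 defaults rtol=1.3e-6 / atol=1e-5 are taken from the assert_close documentation, and the values here are made up:

```python
import torch

a = torch.tensor([100.0])
b = torch.tensor([100.0001])  # abs diff ~1e-4, rel diff ~1e-6

# The closeness criterion is |a - b| <= atol + rtol * |b|.
# If atol=0 were accepted while rtol silently kept its float32
# default of 1.3e-6, these tensors would still "match" via the
# rtol term alone -- likely not what a user setting atol=0 wants.
abs_diff = (a - b).abs().item()
print(abs_diff <= 0.0 + 1.3e-6 * b.abs().item())  # True -> surprising pass
```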
-    torch.max(true_out.float() - out.float()),
-    1.0,
+    torch.testing.assert_close(
+        out, true_out, rtol=0.0, atol=1.0, check_stride=False,
We're comparing the max abs difference here instead of the max (signed) difference in master, but it's probably more correct with the proposed changes.
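A hypothetical pair of tensors showing the behavioral difference (values invented for illustration, not taken from the actual test):

```python
import torch

true_out = torch.tensor([0.0, 0.0])
out = torch.tensor([0.0, 2.0])

# The signed max difference overlooks large negative deviations:
# max(true_out - out) == 0.0, so a "< 1.0" check would pass.
print(torch.max(true_out.float() - out.float()))  # tensor(0.)

# assert_close with rtol=0.0, atol=1.0 bounds the *absolute* difference
# per element, so the |0 - 2| = 2 deviation is caught.
try:
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0)
except AssertionError as exc:
    print("caught:", exc)
```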
…on into assert_close_func_tensor

Summary:
Co-authored-by: Philip Meier <[email protected]>
Reviewed By: vincentqb, cpuhrsch
Differential Revision: D28679977
fbshipit-source-id: 59cd7c52bd5a94a75141c719b85424bf075c65a6
part of #3865