
Use torch.testing.assert_close in test_functional_tensor #3876


Merged: 30 commits into pytorch:master on May 24, 2021

Conversation

@NicolasHug (Member):

part of #3865

@NicolasHug (Member, Author) left a comment:


Mostly LGTM, but I'm wondering whether we could revisit the constraint that both rtol and atol must be specified together.


    if scripted_fn_atol >= 0:
        scripted_fn = torch.jit.script(fn)
        # scriptable function test
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
-       self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))
+       torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
@NicolasHug (Member, Author) commented:

For reference, the rtol=1e-5 value comes from the current default of torch.allclose: https://pytorch.org/docs/stable/generated/torch.allclose.html

I'd prefer to leave rtol at its default in assert_close if possible, but rtol must be set whenever atol is set. Do you know the reason, @pmeier? np.testing.assert_allclose doesn't seem to have this constraint.
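
A minimal sketch of the constraint being discussed, assuming the current torch.testing.assert_close API; the tensors and tolerances below are illustrative:

    import torch

    a = torch.tensor([1.0, 2.0])
    b = a + 1e-4

    # Passing only atol is rejected: assert_close requires rtol and atol together.
    try:
        torch.testing.assert_close(a, b, atol=1e-3)
    except Exception as exc:  # the exact exception type has varied across versions
        print(type(exc).__name__, exc)

    # Specifying both passes; rtol=1e-5 mirrors torch.allclose's default.
    torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-3)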

@pmeier (Collaborator) replied:

We went with this logic because we have non-zero defaults for both rtol and atol. Imagine setting atol=0 and the tensors still matching because the default rtol > 0 absorbs the difference. See https://github.com/pytorch/pytorch/blob/74c12da4517c789bea737dc947d6adc755f63176/torch/testing/_asserts.py#L391-L396.
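
A small illustration of that pitfall, with made-up values; the pass/fail behavior assumes torch.allclose's default rtol=1e-5:

    import torch

    a = torch.tensor([1000.0])
    b = torch.tensor([1000.001])  # absolute difference of 1e-3

    # With torch.allclose, overriding only atol keeps the default rtol=1e-5,
    # so the relative term (1e-5 * 1000 = 1e-2) still absorbs the difference.
    print(torch.allclose(a, b, atol=0.0))  # True, perhaps surprisingly

    # assert_close makes you state both tolerances, so the intent is explicit:
    # this comparison fails because the difference exceeds zero tolerance.
    torch.testing.assert_close(a, b, rtol=0.0, atol=0.0)  # raises AssertionError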

-        torch.max(true_out.float() - out.float()),
-        1.0,
+    torch.testing.assert_close(
+        out, true_out, rtol=0.0, atol=1.0, check_stride=False,
@NicolasHug (Member, Author) commented on May 21, 2021:

We're comparing the max absolute difference here instead of the max (signed) difference as on master, but that's probably more correct with the proposed changes.
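
A toy example of the distinction, with made-up tensors, assuming the master-branch check described above:

    import torch

    out = torch.tensor([4.0, 1.0])
    true_out = torch.tensor([1.0, 2.0])  # differences: [-3.0, 1.0]

    # master: the max *signed* difference is 1.0, so a check against a 1.0
    # threshold passes even though one element is off by 3.
    print(torch.max(true_out.float() - out.float()))  # tensor(1.)

    # proposed: assert_close with rtol=0, atol=1.0 compares absolute differences
    # element-wise and fails on the element that deviates by 3.
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0)  # raises AssertionError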

@NicolasHug merged commit b96d381 into pytorch:master on May 24, 2021
facebook-github-bot pushed a commit that referenced this pull request May 25, 2021

Summary: Co-authored-by: Philip Meier <[email protected]>

Reviewed By: vincentqb, cpuhrsch

Differential Revision: D28679977

fbshipit-source-id: 59cd7c52bd5a94a75141c719b85424bf075c65a6