Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def value_of(
Raises:
RecursionError: If the ParamResolver detects a loop in recursive
resolution.
ValueError: If the resulting value cannot be interpreted.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update to sympy.SympifyError?

"""

# Input is a pass through type, no resolution needed: return early
Expand Down Expand Up @@ -179,7 +180,13 @@ def value_of(
if not recursive:
# Resolves one step at a time. For example:
# a.subs({a: b, b: c}) == b
v = value.subs(self.param_dict, simultaneous=True)
try:
v = value.subs(self.param_dict, simultaneous=True)
except sympy.SympifyError as e: # coverage: ignore
# Lines will be covered in sympy 1.12+
raise ValueError(
f'Could not resolve parameter {value}, underlying error {e}'
) # coverage: ignore
if v.free_symbols:
return v
elif sympy.im(v):
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,11 @@ def _resolved_value_(self):
c = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: bar, c: baz})
assert r.value_of(a) is foo
assert r.value_of(b) is b
assert r.value_of(c) == 'Baz'

with pytest.raises(ValueError, match='Could not resolve parameter b'):
_ = r.value_of(b)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move this to a separate test function marked with a strict xfail, for example,

@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
def test_custom_value_not_implemented():
    ...

When the new sympy is released, the strict xfail will produce a CI failure
after which we can remove the xfail mark.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea done.



def test_compose():
"""Tests that cirq.resolve_parameters on a ParamResolver composes."""
Expand Down