Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def value_of(
Raises:
RecursionError: If the ParamResolver detects a loop in recursive
resolution.
sympy.SympifyError: If the resulting value cannot be interpreted.
"""

# Input is a pass through type, no resolution needed: return early
Expand Down Expand Up @@ -179,7 +180,12 @@ def value_of(
if not recursive:
# Resolves one step at a time. For example:
# a.subs({a: b, b: c}) == b
#
# Note that a sympy.SympifyError here likely means
# that one of the expressions was not parsable by sympy
# (such as a function returning NotImplemented)
v = value.subs(self.param_dict, simultaneous=True)

if v.free_symbols:
return v
elif sympy.im(v):
Expand Down
26 changes: 16 additions & 10 deletions cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,25 +227,31 @@ class Foo:
def _resolved_value_(self):
return self

class Bar:
def _resolved_value_(self):
return NotImplemented

class Baz:
def _resolved_value_(self):
return 'Baz'

foo = Foo()
bar = Bar()
baz = Baz()

a = sympy.Symbol('a')
b = sympy.Symbol('b')
c = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: bar, c: baz})
b = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: baz})
assert r.value_of(a) is foo
assert r.value_of(b) is b
assert r.value_of(c) == 'Baz'
assert r.value_of(b) == 'Baz'


@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
def test_custom_value_not_implemented():
class Bar:
def _resolved_value_(self):
return NotImplemented

b = sympy.Symbol('b')
bar = Bar()
r = cirq.ParamResolver({b: bar})
with pytest.raises(sympy.SympifyError):
_ = r.value_of(b)


def test_compose():
Expand Down