Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class ParamResolver:
Attributes:
param_dict: A dictionary from the ParameterValue key (str) to its
assigned value.

Raises:
TypeError if formulas are passed as keys.
"""

def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None):
Expand All @@ -68,6 +71,9 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None

self._param_hash: Optional[int] = None
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}

def value_of(
Expand Down
20 changes: 3 additions & 17 deletions cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,13 @@ def test_param_dict_iter():


def test_formulas_in_param_dict():
"""Test formulas in a `param_dict`.

Param dicts are allowed to have str or sympy.Symbol as keys and
floats or sympy.Symbol as values. This should not be a common use case,
but this tests makes sure something reasonable is returned when
mixing these types and using formulas in ParamResolvers.

Note that sympy orders expressions for deterministic resolution, so
depending on the operands sent to sub(), the expression may not fully
resolve if it needs to take several iterations of resolution.
"""
"""Tests that formula keys are rejected in a `param_dict`."""
a = sympy.Symbol('a')
b = sympy.Symbol('b')
c = sympy.Symbol('c')
e = sympy.Symbol('e')
r = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
assert sympy.Eq(r.value_of('a'), 3)
assert sympy.Eq(r.value_of('b'), 2)
assert sympy.Eq(r.value_of(b + c), 101)
assert sympy.Eq(r.value_of('c'), c)
assert sympy.Eq(r.value_of('d'), 2 * e)
with pytest.raises(TypeError, match='formula'):
_ = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})


def test_recursive_evaluation():
Expand Down