Skip to content

Commit 5c198ce

Browse files
authored
Allow symbolic scalars in LinearDict (#7003)
* Parameterize LinearDict * Add protocol handlers * fix protocol handlers * fix protocol handlers, add test * format * tests * Y, Z gates, parameterize test * mypy * Revert changes to Scalar, and just use TParamValComplex everywhere. Add tests. * nits
1 parent 3f67923 commit 5c198ce

File tree

5 files changed

+197
-37
lines changed

5 files changed

+197
-37
lines changed

cirq-core/cirq/ops/common_gates.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,12 @@ def controlled(
258258
return result
259259

260260
def _pauli_expansion_(self) -> value.LinearDict[str]:
261-
if protocols.is_parameterized(self) or self._dimension != 2:
261+
if self._dimension != 2:
262262
return NotImplemented
263263
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
264-
angle = np.pi * self._exponent / 2
265-
return value.LinearDict({'I': phase * np.cos(angle), 'X': -1j * phase * np.sin(angle)})
264+
angle = _pi(self._exponent) * self._exponent / 2
265+
lib = sympy if protocols.is_parameterized(self) else np
266+
return value.LinearDict({'I': phase * lib.cos(angle), 'X': -1j * phase * lib.sin(angle)})
266267

267268
def _circuit_diagram_info_(
268269
self, args: 'cirq.CircuitDiagramInfoArgs'
@@ -464,11 +465,10 @@ def _trace_distance_bound_(self) -> Optional[float]:
464465
return abs(np.sin(self._exponent * 0.5 * np.pi))
465466

466467
def _pauli_expansion_(self) -> value.LinearDict[str]:
467-
if protocols.is_parameterized(self):
468-
return NotImplemented
469468
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
470-
angle = np.pi * self._exponent / 2
471-
return value.LinearDict({'I': phase * np.cos(angle), 'Y': -1j * phase * np.sin(angle)})
469+
angle = _pi(self._exponent) * self._exponent / 2
470+
lib = sympy if protocols.is_parameterized(self) else np
471+
return value.LinearDict({'I': phase * lib.cos(angle), 'Y': -1j * phase * lib.sin(angle)})
472472

473473
def _circuit_diagram_info_(
474474
self, args: 'cirq.CircuitDiagramInfoArgs'
@@ -764,11 +764,12 @@ def _trace_distance_bound_(self) -> Optional[float]:
764764
return abs(np.sin(self._exponent * 0.5 * np.pi))
765765

766766
def _pauli_expansion_(self) -> value.LinearDict[str]:
767-
if protocols.is_parameterized(self) or self._dimension != 2:
767+
if self._dimension != 2:
768768
return NotImplemented
769769
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
770-
angle = np.pi * self._exponent / 2
771-
return value.LinearDict({'I': phase * np.cos(angle), 'Z': -1j * phase * np.sin(angle)})
770+
angle = _pi(self._exponent) * self._exponent / 2
771+
lib = sympy if protocols.is_parameterized(self) else np
772+
return value.LinearDict({'I': phase * lib.cos(angle), 'Z': -1j * phase * lib.sin(angle)})
772773

773774
def _phase_by_(self, phase_turns: float, qubit_index: int):
774775
return self

cirq-core/cirq/ops/common_gates_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,3 +1300,13 @@ def test_wrong_dims():
13001300

13011301
with pytest.raises(ValueError, match='Wrong shape'):
13021302
_ = cirq.Z.on(cirq.LineQid(0, dimension=3))
1303+
1304+
1305+
@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate])
1306+
@pytest.mark.parametrize('exponent', [sympy.Symbol('s'), sympy.Symbol('s') * 2])
1307+
def test_parameterized_pauli_expansion(gate_type, exponent):
1308+
gate = gate_type(exponent=exponent)
1309+
pauli = cirq.pauli_expansion(gate)
1310+
gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5})
1311+
pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5})
1312+
assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved))

cirq-core/cirq/value/linear_dict.py

Lines changed: 80 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Linear combination represented as mapping of things to coefficients."""
1616

1717
from typing import (
18+
AbstractSet,
1819
Any,
1920
Callable,
2021
Dict,
@@ -28,21 +29,43 @@
2829
Optional,
2930
overload,
3031
Tuple,
32+
TYPE_CHECKING,
3133
TypeVar,
3234
Union,
3335
ValuesView,
3436
)
3537
from typing_extensions import Self
3638

3739
import numpy as np
40+
import sympy
41+
from cirq import protocols
42+
43+
if TYPE_CHECKING:
44+
import cirq
3845

3946
Scalar = Union[complex, np.number]
4047
TVector = TypeVar('TVector')
4148

4249
TDefault = TypeVar('TDefault')
4350

4451

45-
def _format_coefficient(format_spec: str, coefficient: Scalar) -> str:
52+
class _SympyPrinter(sympy.printing.str.StrPrinter):
53+
def __init__(self, format_spec: str):
54+
super().__init__()
55+
self._format_spec = format_spec
56+
57+
def _print(self, expr, **kwargs):
58+
if expr.is_complex:
59+
coefficient = complex(expr)
60+
s = _format_coefficient(self._format_spec, coefficient)
61+
return s[1:-1] if s.startswith('(') else s
62+
return super()._print(expr, **kwargs)
63+
64+
65+
def _format_coefficient(format_spec: str, coefficient: 'cirq.TParamValComplex') -> str:
66+
if isinstance(coefficient, sympy.Basic):
67+
printer = _SympyPrinter(format_spec)
68+
return printer.doprint(coefficient)
4669
coefficient = complex(coefficient)
4770
real_str = f'{coefficient.real:{format_spec}}'
4871
imag_str = f'{coefficient.imag:{format_spec}}'
@@ -59,7 +82,7 @@ def _format_coefficient(format_spec: str, coefficient: Scalar) -> str:
5982
return f'({real_str}+{imag_str}j)'
6083

6184

62-
def _format_term(format_spec: str, vector: TVector, coefficient: Scalar) -> str:
85+
def _format_term(format_spec: str, vector: TVector, coefficient: 'cirq.TParamValComplex') -> str:
6386
coefficient_str = _format_coefficient(format_spec, coefficient)
6487
if not coefficient_str:
6588
return coefficient_str
@@ -69,7 +92,7 @@ def _format_term(format_spec: str, vector: TVector, coefficient: Scalar) -> str:
6992
return '+' + result
7093

7194

72-
def _format_terms(terms: Iterable[Tuple[TVector, Scalar]], format_spec: str):
95+
def _format_terms(terms: Iterable[Tuple[TVector, 'cirq.TParamValComplex']], format_spec: str):
7396
formatted_terms = [_format_term(format_spec, vector, coeff) for vector, coeff in terms]
7497
s = ''.join(formatted_terms)
7598
if not s:
@@ -79,7 +102,7 @@ def _format_terms(terms: Iterable[Tuple[TVector, Scalar]], format_spec: str):
79102
return s
80103

81104

82-
class LinearDict(Generic[TVector], MutableMapping[TVector, Scalar]):
105+
class LinearDict(Generic[TVector], MutableMapping[TVector, 'cirq.TParamValComplex']):
83106
"""Represents linear combination of things.
84107
85108
LinearDict implements the basic linear algebraic operations of vector
@@ -96,7 +119,7 @@ class LinearDict(Generic[TVector], MutableMapping[TVector, Scalar]):
96119

97120
def __init__(
98121
self,
99-
terms: Optional[Mapping[TVector, Scalar]] = None,
122+
terms: Optional[Mapping[TVector, 'cirq.TParamValComplex']] = None,
100123
validator: Optional[Callable[[TVector], bool]] = None,
101124
) -> None:
102125
"""Initializes linear combination from a collection of terms.
@@ -112,21 +135,30 @@ def __init__(
112135
"""
113136
self._has_validator = validator is not None
114137
self._is_valid = validator or (lambda x: True)
115-
self._terms: Dict[TVector, Scalar] = {}
138+
self._terms: Dict[TVector, 'cirq.TParamValComplex'] = {}
116139
if terms is not None:
117140
self.update(terms)
118141

119142
@classmethod
120143
def fromkeys(cls, vectors, coefficient=0):
121-
return LinearDict(dict.fromkeys(vectors, complex(coefficient)))
144+
return LinearDict(
145+
dict.fromkeys(
146+
vectors,
147+
coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient),
148+
)
149+
)
122150

123151
def _check_vector_valid(self, vector: TVector) -> None:
124152
if not self._is_valid(vector):
125153
raise ValueError(f'{vector} is not compatible with linear combination {self}')
126154

127155
def clean(self, *, atol: float = 1e-9) -> Self:
128156
"""Remove terms with coefficients of absolute value atol or less."""
129-
negligible = [v for v, c in self._terms.items() if abs(complex(c)) <= atol]
157+
negligible = [
158+
v
159+
for v, c in self._terms.items()
160+
if not isinstance(c, sympy.Basic) and abs(complex(c)) <= atol
161+
]
130162
for v in negligible:
131163
del self._terms[v]
132164
return self
@@ -139,40 +171,50 @@ def keys(self) -> KeysView[TVector]:
139171
snapshot = self.copy().clean(atol=0)
140172
return snapshot._terms.keys()
141173

142-
def values(self) -> ValuesView[Scalar]:
174+
def values(self) -> ValuesView['cirq.TParamValComplex']:
143175
snapshot = self.copy().clean(atol=0)
144176
return snapshot._terms.values()
145177

146-
def items(self) -> ItemsView[TVector, Scalar]:
178+
def items(self) -> ItemsView[TVector, 'cirq.TParamValComplex']:
147179
snapshot = self.copy().clean(atol=0)
148180
return snapshot._terms.items()
149181

150182
# pylint: disable=function-redefined
151183
@overload
152-
def update(self, other: Mapping[TVector, Scalar], **kwargs: Scalar) -> None:
184+
def update(
185+
self, other: Mapping[TVector, 'cirq.TParamValComplex'], **kwargs: 'cirq.TParamValComplex'
186+
) -> None:
153187
pass
154188

155189
@overload
156-
def update(self, other: Iterable[Tuple[TVector, Scalar]], **kwargs: Scalar) -> None:
190+
def update(
191+
self,
192+
other: Iterable[Tuple[TVector, 'cirq.TParamValComplex']],
193+
**kwargs: 'cirq.TParamValComplex',
194+
) -> None:
157195
pass
158196

159197
@overload
160-
def update(self, *args: Any, **kwargs: Scalar) -> None:
198+
def update(self, *args: Any, **kwargs: 'cirq.TParamValComplex') -> None:
161199
pass
162200

163201
def update(self, *args, **kwargs):
164202
terms = dict()
165203
terms.update(*args, **kwargs)
166204
for vector, coefficient in terms.items():
205+
if isinstance(coefficient, sympy.Basic):
206+
coefficient = sympy.simplify(coefficient)
207+
if coefficient.is_complex:
208+
coefficient = complex(coefficient)
167209
self[vector] = coefficient
168210
self.clean(atol=0)
169211

170212
@overload
171-
def get(self, vector: TVector) -> Scalar:
213+
def get(self, vector: TVector) -> 'cirq.TParamValComplex':
172214
pass
173215

174216
@overload
175-
def get(self, vector: TVector, default: TDefault) -> Union[Scalar, TDefault]:
217+
def get(self, vector: TVector, default: TDefault) -> Union['cirq.TParamValComplex', TDefault]:
176218
pass
177219

178220
def get(self, vector, default=0):
@@ -185,10 +227,10 @@ def get(self, vector, default=0):
185227
def __contains__(self, vector: Any) -> bool:
186228
return vector in self._terms and self._terms[vector] != 0
187229

188-
def __getitem__(self, vector: TVector) -> Scalar:
230+
def __getitem__(self, vector: TVector) -> 'cirq.TParamValComplex':
189231
return self._terms.get(vector, 0)
190232

191-
def __setitem__(self, vector: TVector, coefficient: Scalar) -> None:
233+
def __setitem__(self, vector: TVector, coefficient: 'cirq.TParamValComplex') -> None:
192234
self._check_vector_valid(vector)
193235
if coefficient != 0:
194236
self._terms[vector] = coefficient
@@ -236,21 +278,21 @@ def __neg__(self) -> Self:
236278
factory = type(self)
237279
return factory({v: -c for v, c in self.items()})
238280

239-
def __imul__(self, a: Scalar) -> Self:
281+
def __imul__(self, a: 'cirq.TParamValComplex') -> Self:
240282
for vector in self:
241283
self._terms[vector] *= a
242284
self.clean(atol=0)
243285
return self
244286

245-
def __mul__(self, a: Scalar) -> Self:
287+
def __mul__(self, a: 'cirq.TParamValComplex') -> Self:
246288
result = self.copy()
247289
result *= a
248-
return result
290+
return result.copy()
249291

250-
def __rmul__(self, a: Scalar) -> Self: # type: ignore
292+
def __rmul__(self, a: 'cirq.TParamValComplex') -> Self:
251293
return self.__mul__(a)
252294

253-
def __truediv__(self, a: Scalar) -> Self:
295+
def __truediv__(self, a: 'cirq.TParamValComplex') -> Self:
254296
return self.__mul__(1 / a)
255297

256298
def __bool__(self) -> bool:
@@ -320,3 +362,19 @@ def _json_dict_(self) -> Dict[Any, Any]:
320362
@classmethod
321363
def _from_json_dict_(cls, keys, values, **kwargs):
322364
return cls(terms=dict(zip(keys, values)))
365+
366+
def _is_parameterized_(self) -> bool:
367+
return any(protocols.is_parameterized(v) for v in self._terms.values())
368+
369+
def _parameter_names_(self) -> AbstractSet[str]:
370+
return set(name for v in self._terms.values() for name in protocols.parameter_names(v))
371+
372+
def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'LinearDict':
373+
result = self.copy()
374+
result.update(
375+
{
376+
k: protocols.resolve_parameters(v, resolver, recursive)
377+
for k, v in self._terms.items()
378+
}
379+
)
380+
return result

0 commit comments

Comments
 (0)