Skip to content

Commit dee2c5c

Browse files
daxfohlpavoljuhas
andauthored
Allow symbolic scalars in linear combinations (#7030)
* Parameterize LinearDict * Add protocol handlers * fix protocol handlers * fix protocol handlers, add test * format * tests * Y, Z gates, parameterize test * mypy * Revert changes to Scalar, and just use TParamValComplex everywhere. Add tests. * Parameterize LinearCombinationOfGates * tests * tests * tests * tests * test * Update cirq-core/cirq/ops/raw_types.py Co-authored-by: Pavol Juhas <pavol.juhas@gmail.com> --------- Co-authored-by: Pavol Juhas <pavol.juhas@gmail.com>
1 parent 0e50a8c commit dee2c5c

File tree

7 files changed

+78
-23
lines changed

7 files changed

+78
-23
lines changed

cirq-core/cirq/linalg/operator_spaces.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
# limitations under the License.
1414

1515
"""Utilities for manipulating linear operators as elements of vector space."""
16-
from typing import Dict, Tuple
16+
from typing import Dict, Tuple, TYPE_CHECKING
1717

1818
import numpy as np
19+
import sympy
1920

2021
from cirq import value
2122
from cirq._doc import document
2223

24+
if TYPE_CHECKING:
25+
import cirq
26+
2327
PAULI_BASIS = {
2428
'I': np.eye(2),
2529
'X': np.array([[0.0, 1.0], [1.0, 0.0]]),
@@ -78,8 +82,17 @@ def matrix_from_basis_coefficients(
7882

7983

8084
def pow_pauli_combination(
81-
ai: value.Scalar, ax: value.Scalar, ay: value.Scalar, az: value.Scalar, exponent: int
82-
) -> Tuple[value.Scalar, value.Scalar, value.Scalar, value.Scalar]:
85+
ai: 'cirq.TParamValComplex',
86+
ax: 'cirq.TParamValComplex',
87+
ay: 'cirq.TParamValComplex',
88+
az: 'cirq.TParamValComplex',
89+
exponent: int,
90+
) -> Tuple[
91+
'cirq.TParamValComplex',
92+
'cirq.TParamValComplex',
93+
'cirq.TParamValComplex',
94+
'cirq.TParamValComplex',
95+
]:
8396
"""Computes non-negative integer power of single-qubit Pauli combination.
8497
8598
Returns scalar coefficients bi, bx, by, bz such that
@@ -96,7 +109,10 @@ def pow_pauli_combination(
96109
if exponent == 0:
97110
return 1, 0, 0, 0
98111

99-
v = np.sqrt(ax * ax + ay * ay + az * az).item()
112+
if any(isinstance(a, sympy.Basic) for a in [ax, ay, az]):
113+
v = sympy.sqrt(ax * ax + ay * ay + az * az)
114+
else:
115+
v = np.sqrt(ax * ax + ay * ay + az * az).item()
100116
s = (ai + v) ** exponent
101117
t = (ai - v) ** exponent
102118

cirq-core/cirq/linalg/operator_spaces_test.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import pytest
1919
import scipy.linalg
20+
import sympy
2021

2122
import cirq
2223

@@ -287,11 +288,21 @@ def test_expand_is_inverse_of_reconstruct(m1, basis):
287288
(-1, -2, 3, 4),
288289
(1j, 2j, 3j, 4j),
289290
(1j, 2j, 3, 4),
291+
(sympy.Symbol('i'), sympy.Symbol('x'), sympy.Symbol('y'), sympy.Symbol('z')),
292+
(
293+
sympy.Symbol('i') * 1j,
294+
-sympy.Symbol('x'),
295+
-sympy.Symbol('y') * 1j,
296+
sympy.Symbol('z'),
297+
),
290298
),
291299
(0, 1, 2, 3, 4, 5, 100, 101),
292300
),
293301
)
294302
def test_pow_pauli_combination(coefficients, exponent):
303+
is_symbolic = any(isinstance(a, sympy.Basic) for a in coefficients)
304+
if is_symbolic and exponent > 2:
305+
return # too slow
295306
i = cirq.PAULI_BASIS['I']
296307
x = cirq.PAULI_BASIS['X']
297308
y = cirq.PAULI_BASIS['Y']
@@ -303,5 +314,7 @@ def test_pow_pauli_combination(coefficients, exponent):
303314

304315
bi, bx, by, bz = cirq.pow_pauli_combination(ai, ax, ay, az, exponent)
305316
result = bi * i + bx * x + by * y + bz * z
306-
307-
assert np.allclose(result, expected_result)
317+
if is_symbolic:
318+
assert cirq.approx_eq(result, expected_result)
319+
else:
320+
assert np.allclose(result, expected_result)

cirq-core/cirq/ops/common_gates.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
261261
if self._dimension != 2:
262262
return NotImplemented
263263
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
264-
angle = _pi(self._exponent) * self._exponent / 2
265264
lib = sympy if protocols.is_parameterized(self) else np
265+
angle = lib.pi * self._exponent / 2
266266
return value.LinearDict({'I': phase * lib.cos(angle), 'X': -1j * phase * lib.sin(angle)})
267267

268268
def _circuit_diagram_info_(
@@ -466,8 +466,8 @@ def _trace_distance_bound_(self) -> Optional[float]:
466466

467467
def _pauli_expansion_(self) -> value.LinearDict[str]:
468468
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
469-
angle = _pi(self._exponent) * self._exponent / 2
470469
lib = sympy if protocols.is_parameterized(self) else np
470+
angle = lib.pi * self._exponent / 2
471471
return value.LinearDict({'I': phase * lib.cos(angle), 'Y': -1j * phase * lib.sin(angle)})
472472

473473
def _circuit_diagram_info_(
@@ -767,8 +767,8 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
767767
if self._dimension != 2:
768768
return NotImplemented
769769
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
770-
angle = _pi(self._exponent) * self._exponent / 2
771770
lib = sympy if protocols.is_parameterized(self) else np
771+
angle = lib.pi * self._exponent / 2
772772
return value.LinearDict({'I': phase * lib.cos(angle), 'Z': -1j * phase * lib.sin(angle)})
773773

774774
def _phase_by_(self, phase_turns: float, qubit_index: int):

cirq-core/cirq/ops/linear_combinations.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class LinearCombinationOfGates(value.LinearDict[raw_types.Gate]):
8181
2 * cirq.X - 2 * cirq.Z
8282
"""
8383

84-
def __init__(self, terms: Mapping[raw_types.Gate, value.Scalar]) -> None:
84+
def __init__(self, terms: Mapping[raw_types.Gate, 'cirq.TParamValComplex']) -> None:
8585
"""Initializes linear combination from a collection of terms.
8686
8787
Args:
@@ -149,17 +149,19 @@ def __pow__(self, exponent: int) -> 'LinearCombinationOfGates':
149149
)
150150

151151
def _is_parameterized_(self) -> bool:
152-
return any(protocols.is_parameterized(gate) for gate in self.keys())
152+
return any(protocols.is_parameterized(item) for item in self.items())
153153

154154
def _parameter_names_(self) -> AbstractSet[str]:
155-
return {name for gate in self.keys() for name in protocols.parameter_names(gate)}
155+
return {name for item in self.items() for name in protocols.parameter_names(item)}
156156

157157
def _resolve_parameters_(
158158
self, resolver: 'cirq.ParamResolver', recursive: bool
159159
) -> 'LinearCombinationOfGates':
160160
return self.__class__(
161161
{
162-
protocols.resolve_parameters(gate, resolver, recursive): coeff
162+
protocols.resolve_parameters(
163+
gate, resolver, recursive
164+
): protocols.resolve_parameters(coeff, resolver, recursive)
163165
for gate, coeff in self.items()
164166
}
165167
)
@@ -222,7 +224,7 @@ class LinearCombinationOfOperations(value.LinearDict[raw_types.Operation]):
222224
by the identity operator. Note that A may not be unitary or even normal.
223225
"""
224226

225-
def __init__(self, terms: Mapping[raw_types.Operation, value.Scalar]) -> None:
227+
def __init__(self, terms: Mapping[raw_types.Operation, 'cirq.TParamValComplex']) -> None:
226228
"""Initializes linear combination from a collection of terms.
227229
228230
Args:
@@ -264,17 +266,19 @@ def __pow__(self, exponent: int) -> 'LinearCombinationOfOperations':
264266
return LinearCombinationOfOperations({i: bi, x: bx, y: by, z: bz})
265267

266268
def _is_parameterized_(self) -> bool:
267-
return any(protocols.is_parameterized(op) for op in self.keys())
269+
return any(protocols.is_parameterized(item) for item in self.items())
268270

269271
def _parameter_names_(self) -> AbstractSet[str]:
270-
return {name for op in self.keys() for name in protocols.parameter_names(op)}
272+
return {name for item in self.items() for name in protocols.parameter_names(item)}
271273

272274
def _resolve_parameters_(
273275
self, resolver: 'cirq.ParamResolver', recursive: bool
274276
) -> 'LinearCombinationOfOperations':
275277
return self.__class__(
276278
{
277-
protocols.resolve_parameters(op, resolver, recursive): coeff
279+
protocols.resolve_parameters(op, resolver, recursive): protocols.resolve_parameters(
280+
coeff, resolver, recursive
281+
)
278282
for op, coeff in self.items()
279283
}
280284
)
@@ -353,7 +357,9 @@ def _is_linear_dict_of_unit_pauli_string(linear_dict: value.LinearDict[UnitPauli
353357
return True
354358

355359

356-
def _pauli_string_from_unit(unit: UnitPauliStringT, coefficient: Union[int, float, complex] = 1):
360+
def _pauli_string_from_unit(
361+
unit: UnitPauliStringT, coefficient: Union[int, float, 'cirq.TParamValComplex'] = 1
362+
):
357363
return PauliString(qubit_pauli_map=dict(unit), coefficient=coefficient)
358364

359365

cirq-core/cirq/ops/linear_combinations_test.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def test_non_unitary_linear_combination_of_gates_has_no_unitary(terms):
153153
),
154154
({cirq.X: 2, cirq.H: 1}, {'X': 2 + np.sqrt(0.5), 'Z': np.sqrt(0.5)}),
155155
({cirq.XX: -2, cirq.YY: 3j, cirq.ZZ: 4}, {'XX': -2, 'YY': 3j, 'ZZ': 4}),
156+
(
157+
{cirq.X: sympy.Symbol('x'), cirq.Y: -sympy.Symbol('y')},
158+
{'X': sympy.Symbol('x'), 'Y': -sympy.Symbol('y')},
159+
),
156160
),
157161
)
158162
def test_linear_combination_of_gates_has_correct_pauli_expansion(terms, expected_expansion):
@@ -206,7 +210,11 @@ def test_linear_combinations_of_gates_invalid_powers(terms, exponent):
206210

207211
@pytest.mark.parametrize(
208212
'terms, is_parameterized, parameter_names',
209-
[({cirq.H: 1}, False, set()), ({cirq.X ** sympy.Symbol('t'): 1}, True, {'t'})],
213+
[
214+
({cirq.H: 1}, False, set()),
215+
({cirq.X ** sympy.Symbol('t'): 1}, True, {'t'}),
216+
({cirq.X: sympy.Symbol('t')}, True, {'t'}),
217+
],
210218
)
211219
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
212220
def test_parameterized_linear_combination_of_gates(
@@ -225,7 +233,7 @@ def get_matrix(
225233
cirq.GateOperation,
226234
cirq.LinearCombinationOfGates,
227235
cirq.LinearCombinationOfOperations,
228-
]
236+
],
229237
) -> np.ndarray:
230238
if isinstance(operator, (cirq.LinearCombinationOfGates, cirq.LinearCombinationOfOperations)):
231239
return operator.matrix()
@@ -243,13 +251,13 @@ def assert_linear_combinations_are_equal(
243251

244252
actual_matrix = get_matrix(actual)
245253
expected_matrix = get_matrix(expected)
246-
assert np.allclose(actual_matrix, expected_matrix)
254+
assert cirq.approx_eq(actual_matrix, expected_matrix)
247255

248256
actual_expansion = cirq.pauli_expansion(actual)
249257
expected_expansion = cirq.pauli_expansion(expected)
250258
assert set(actual_expansion.keys()) == set(expected_expansion.keys())
251259
for name in actual_expansion.keys():
252-
assert abs(actual_expansion[name] - expected_expansion[name]) < 1e-12
260+
assert cirq.approx_eq(actual_expansion[name], expected_expansion[name])
253261

254262

255263
@pytest.mark.parametrize(
@@ -279,6 +287,8 @@ def assert_linear_combinations_are_equal(
279287
),
280288
((cirq.X + cirq.Y + cirq.Z) ** 0, cirq.I),
281289
((cirq.X - 1j * cirq.Y) ** 0, cirq.I),
290+
(cirq.Y - sympy.Symbol('s') * cirq.Y, (1 - sympy.Symbol('s')) * cirq.Y),
291+
((cirq.X + cirq.Z) * sympy.Symbol('s') / np.sqrt(2), cirq.H * sympy.Symbol('s')),
282292
),
283293
)
284294
def test_gate_expressions(expression, expected_result):
@@ -659,6 +669,10 @@ def test_non_unitary_linear_combination_of_operations_has_no_unitary(terms):
659669
{'IIZI': 1, 'IZII': 1, 'IZZI': -1},
660670
),
661671
({cirq.CNOT(q0, q1): 2, cirq.Z(q0): -1, cirq.X(q1): -1}, {'II': 1, 'ZX': -1}),
672+
(
673+
{cirq.X(q0): -sympy.Symbol('x'), cirq.Y(q0): sympy.Symbol('y')},
674+
{'X': -sympy.Symbol('x'), 'Y': sympy.Symbol('y')},
675+
),
662676
),
663677
)
664678
def test_linear_combination_of_operations_has_correct_pauli_expansion(terms, expected_expansion):
@@ -716,6 +730,7 @@ def test_linear_combinations_of_operations_invalid_powers(terms, exponent):
716730
[
717731
({cirq.H(cirq.LineQubit(0)): 1}, False, set()),
718732
({cirq.X(cirq.LineQubit(0)) ** sympy.Symbol('t'): 1}, True, {'t'}),
733+
({cirq.X(cirq.LineQubit(0)): sympy.Symbol('t')}, True, {'t'}),
719734
],
720735
)
721736
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
@@ -788,6 +803,10 @@ def test_parameterized_linear_combination_of_ops(
788803
cirq.LinearCombinationOfOperations({cirq.X(q1): 2, cirq.Z(q1): 3}) ** 0,
789804
cirq.LinearCombinationOfOperations({cirq.I(q1): 1}),
790805
),
806+
(
807+
cirq.LinearCombinationOfOperations({cirq.X(q0): sympy.Symbol('s')}) ** 2,
808+
cirq.LinearCombinationOfOperations({cirq.I(q0): sympy.Symbol('s') ** 2}),
809+
),
791810
),
792811
)
793812
def test_operation_expressions(expression, expected_result):

cirq-core/cirq/ops/raw_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation']
306306
return operations
307307

308308
def wrap_in_linear_combination(
309-
self, coefficient: Union[complex, float, int] = 1
309+
self, coefficient: 'cirq.TParamValComplex' = 1
310310
) -> 'cirq.LinearCombinationOfGates':
311311
"""Returns a LinearCombinationOfGates with this gate.
312312

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def _decompose_(self, qubits):
258258
(cirq.CZ * 1, cirq.CZ / 1),
259259
(-cirq.CSWAP * 1j, cirq.CSWAP / 1j),
260260
(cirq.TOFFOLI * 0.5, cirq.TOFFOLI / 2),
261+
(-cirq.X * sympy.Symbol('s'), -sympy.Symbol('s') * cirq.X),
261262
),
262263
)
263264
def test_gate_algebra(expression, expected_result):

0 commit comments

Comments
 (0)