Skip to content

Commit 6f1a975

Browse files
dstrain115pavoljuhas
authored andcommitted
Fix kraus channels for fallbacks to super-operators (#7537)
- There were a few issues with the new strategy to fall back to super-operator calculations using apply_channels. - First, the super_operator_to_kraus function is generally not precise or numerically stable enough to support an atol of 1e-10 so loosened this to 1e-6 and also allowed the ability to specify this as a parameter. - Next, some operators define decomposition, so cirq.kraus should try to decompose and get a unitary before falling back to using apply_channel. Fixes: #7536
1 parent 7fd4a7b commit 6f1a975

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

cirq-core/cirq/circuits/moment_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ def test_kraus() -> None:
885885

886886

887887
def test_kraus_too_big() -> None:
888-
m = cirq.Moment(cirq.IdentityGate(11).on(*cirq.LineQubit.range(11)))
888+
m = cirq.Moment(cirq.IdentityGate(11).with_probability(0.5).on(*cirq.LineQubit.range(11)))
889889
assert not cirq.has_kraus(m)
890890
assert not m._has_superoperator_()
891891
assert m._kraus_() is NotImplemented

cirq-core/cirq/protocols/kraus_protocol.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from cirq import protocols, qis
2727
from cirq._doc import doc_private
2828
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
29+
from cirq.protocols.has_unitary_protocol import has_unitary
2930
from cirq.protocols.mixture_protocol import has_mixture
3031
from cirq.protocols.unitary_protocol import unitary
3132

@@ -95,9 +96,14 @@ def _has_kraus_(self) -> bool:
9596
"""
9697

9798

98-
def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
99+
def _strat_kraus_from_apply_channel(val: Any, atol: float) -> tuple[np.ndarray, ...] | None:
99100
"""Attempts to compute a value's Kraus operators via its _apply_channel_ method.
100-
This is very expensive (O(16^N)), so only do this as a last resort."""
101+
This is very expensive (O(16^N)), so only do this as a last resort.
102+
103+
Args:
104+
val: value to calculate kraus channels from.
105+
atol: Absolute tolerance for super-operator calculation.
106+
Matrices with all entries less than this will be dropped."""
101107
method = getattr(val, '_apply_channel_', None)
102108
if method is None:
103109
return None
@@ -122,12 +128,15 @@ def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
122128
if superop is None or superop is NotImplemented:
123129
return None
124130
n = np.prod(qid_shape) ** 2
125-
kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n)))
131+
# Note that super-operator calculations can be numerically unstable
132+
# and we want to avoid returning kraus channels with "almost zero"
133+
# components
134+
kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n)), atol=atol)
126135
return tuple(kraus_ops)
127136

128137

129138
def kraus(
130-
val: Any, default: Any = RaiseTypeErrorIfNotProvided
139+
val: Any, default: Any = RaiseTypeErrorIfNotProvided, atol: float = 1e-6
131140
) -> tuple[np.ndarray, ...] | TDefault:
132141
r"""Returns a list of matrices describing the channel for the given value.
133142
@@ -149,6 +158,8 @@ def kraus(
149158
default: Determines the fallback behavior when `val` doesn't have
150159
a channel. If `default` is not set, a TypeError is raised. If
151160
default is set to a value, that value is returned.
161+
atol: If calculating Kraus channels from channels, use this tolerance
162+
for determining whether a super-operator is all zeros.
152163
153164
Returns:
154165
If `val` has a `_kraus_` method and its result is not NotImplemented,
@@ -187,6 +198,9 @@ def kraus(
187198
if unitary_result is not NotImplemented and unitary_result is not None:
188199
return (unitary_result,)
189200

201+
if has_unitary(val):
202+
return (unitary(val),)
203+
190204
channel_result = NotImplemented if channel_getter is None else channel_getter()
191205
if channel_result is not NotImplemented:
192206
return tuple(channel_result) # pragma: no cover
@@ -195,7 +209,7 @@ def kraus(
195209
# Note: _apply_channel can lead to kraus being called again, so if default
196210
# is None, this can trigger an infinite loop.
197211
if default is not None:
198-
result = _strat_kraus_from_apply_channel(val)
212+
result = _strat_kraus_from_apply_channel(val, atol)
199213
if result is not None:
200214
return result
201215

cirq-core/cirq/protocols/kraus_protocol_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_has_kraus(cls) -> None:
167167
assert cirq.has_kraus(cls())
168168

169169

170-
@pytest.mark.parametrize('decomposed_cls', [HasKraus, HasMixture, HasUnitary])
170+
@pytest.mark.parametrize('decomposed_cls', [HasKraus, HasMixture])
171171
def test_has_kraus_when_decomposed(decomposed_cls) -> None:
172172
op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test'))
173173
assert cirq.has_kraus(op)
@@ -243,3 +243,11 @@ def _kraus_(self):
243243
gate_no_kraus = NoKrausReset()
244244
# Should still match the original superoperator
245245
np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8)
246+
247+
248+
def test_kraus_channel_with_has_unitary():
249+
"""CZSWAP has no unitary dunder method but has_unitary returns True."""
250+
op = cirq.CZSWAP.on(cirq.q(1), cirq.q(2))
251+
channels = cirq.kraus(op)
252+
assert len(channels) == 1
253+
np.testing.assert_allclose(channels[0], cirq.unitary(op))

0 commit comments

Comments
 (0)