Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions cirq-core/cirq/protocols/kraus_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from cirq import protocols, qis
from cirq._doc import doc_private
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.protocols.has_unitary_protocol import has_unitary
from cirq.protocols.mixture_protocol import has_mixture
from cirq.protocols.unitary_protocol import unitary

Expand Down Expand Up @@ -95,7 +96,7 @@ def _has_kraus_(self) -> bool:
"""


def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
def _strat_kraus_from_apply_channel(val: Any, atol: float) -> tuple[np.ndarray, ...] | None:
Comment thread
mhucka marked this conversation as resolved.
"""Attempts to compute a value's Kraus operators via its _apply_channel_ method.
This is very expensive (O(16^N)), so only do this as a last resort."""
method = getattr(val, '_apply_channel_', None)
Expand All @@ -122,12 +123,15 @@ def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
if superop is None or superop is NotImplemented:
return None
n = np.prod(qid_shape) ** 2
kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n)))
# Note that super-operator calculations can be numerically unstable
# and we want to avoid returning kraus channels with "almost zero"
# components
kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n)), atol=atol)
return tuple(kraus_ops)


def kraus(
val: Any, default: Any = RaiseTypeErrorIfNotProvided
val: Any, default: Any = RaiseTypeErrorIfNotProvided, atol: float = 1e-6
) -> tuple[np.ndarray, ...] | TDefault:
r"""Returns a list of matrices describing the channel for the given value.

Expand All @@ -149,6 +153,8 @@ def kraus(
default: Determines the fallback behavior when `val` doesn't have
a channel. If `default` is not set, a TypeError is raised. If
default is set to a value, that value is returned.
atol: If calculating kraus channels from channels, use this tolerance
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: kraus → Kraus

for determining whether a super-operator is all zeros.

Returns:
If `val` has a `_kraus_` method and its result is not NotImplemented,
Expand Down Expand Up @@ -187,6 +193,9 @@ def kraus(
if unitary_result is not NotImplemented and unitary_result is not None:
return (unitary_result,)

if has_unitary(val):
return (unitary(val),)

channel_result = NotImplemented if channel_getter is None else channel_getter()
if channel_result is not NotImplemented:
return tuple(channel_result) # pragma: no cover
Expand All @@ -195,7 +204,7 @@ def kraus(
# Note: _apply_channel can lead to kraus being called again, so if default
# is None, this can trigger an infinite loop.
if default is not None:
result = _strat_kraus_from_apply_channel(val)
result = _strat_kraus_from_apply_channel(val, atol)
if result is not None:
return result

Expand Down
13 changes: 12 additions & 1 deletion cirq-core/cirq/protocols/kraus_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ class HasUnitary(cirq.testing.SingleQubitGate):
def _has_unitary_(self) -> bool:
return True

def _unitary_(self) -> np.ndarray:
return np.asarray([[1, 0], [0, 1]])


class HasKrausWhenDecomposed(cirq.testing.SingleQubitGate):
def __init__(self, decomposed_cls):
Expand All @@ -167,7 +170,7 @@ def test_has_kraus(cls) -> None:
assert cirq.has_kraus(cls())


@pytest.mark.parametrize('decomposed_cls', [HasKraus, HasMixture, HasUnitary])
@pytest.mark.parametrize('decomposed_cls', [HasKraus, HasMixture])
def test_has_kraus_when_decomposed(decomposed_cls) -> None:
op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test'))
assert cirq.has_kraus(op)
Expand Down Expand Up @@ -243,3 +246,11 @@ def _kraus_(self):
gate_no_kraus = NoKrausReset()
# Should still match the original superoperator
np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8)


def test_kraus_channel_with_has_unitary():
"""CZSWAP is a gate with no unitary but has a unitary."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is something wrong with this docstring ;-)

op = cirq.CZSWAP.on(cirq.q(1), cirq.q(2))
channels = cirq.kraus(op)
assert len(channels) == 1
np.testing.assert_allclose(channels[0], cirq.unitary(op))
Loading