2626from cirq import protocols , qis
2727from cirq ._doc import doc_private
2828from cirq .protocols .decompose_protocol import _try_decompose_into_operations_and_qubits
29+ from cirq .protocols .has_unitary_protocol import has_unitary
2930from cirq .protocols .mixture_protocol import has_mixture
3031from cirq .protocols .unitary_protocol import unitary
3132
@@ -95,9 +96,14 @@ def _has_kraus_(self) -> bool:
9596 """
9697
9798
98- def _strat_kraus_from_apply_channel (val : Any ) -> tuple [np .ndarray , ...] | None :
99+ def _strat_kraus_from_apply_channel (val : Any , atol : float ) -> tuple [np .ndarray , ...] | None :
99100 """Attempts to compute a value's Kraus operators via its _apply_channel_ method.
100- This is very expensive (O(16^N)), so only do this as a last resort."""
101+ This is very expensive (O(16^N)), so only do this as a last resort.
102+
103+ Args:
104+ val: value to calculate kraus channels from.
105+ atol: Absolute tolerance for super-operator calculation.
106+ Matrices with all entries less than this will be dropped."""
101107 method = getattr (val , '_apply_channel_' , None )
102108 if method is None :
103109 return None
@@ -122,12 +128,15 @@ def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
122128 if superop is None or superop is NotImplemented :
123129 return None
124130 n = np .prod (qid_shape ) ** 2
125- kraus_ops = qis .superoperator_to_kraus (superop .reshape ((n , n )))
131+ # Note that super-operator calculations can be numerically unstable
132+ # and we want to avoid returning kraus channels with "almost zero"
133+ # components
134+ kraus_ops = qis .superoperator_to_kraus (superop .reshape ((n , n )), atol = atol )
126135 return tuple (kraus_ops )
127136
128137
129138def kraus (
130- val : Any , default : Any = RaiseTypeErrorIfNotProvided
139+ val : Any , default : Any = RaiseTypeErrorIfNotProvided , atol : float = 1e-6
131140) -> tuple [np .ndarray , ...] | TDefault :
132141 r"""Returns a list of matrices describing the channel for the given value.
133142
@@ -149,6 +158,8 @@ def kraus(
149158 default: Determines the fallback behavior when `val` doesn't have
150159 a channel. If `default` is not set, a TypeError is raised. If
151160 default is set to a value, that value is returned.
161+ atol: If calculating Kraus channels from channels, use this tolerance
162+ for determining whether a super-operator is all zeros.
152163
153164 Returns:
154165 If `val` has a `_kraus_` method and its result is not NotImplemented,
@@ -187,6 +198,9 @@ def kraus(
187198 if unitary_result is not NotImplemented and unitary_result is not None :
188199 return (unitary_result ,)
189200
201+ if has_unitary (val ):
202+ return (unitary (val ),)
203+
190204 channel_result = NotImplemented if channel_getter is None else channel_getter ()
191205 if channel_result is not NotImplemented :
192206 return tuple (channel_result ) # pragma: no cover
@@ -195,7 +209,7 @@ def kraus(
195209 # Note: _apply_channel can lead to kraus being called again, so if default
196210 # is None, this can trigger an infinite loop.
197211 if default is not None :
198- result = _strat_kraus_from_apply_channel (val )
212+ result = _strat_kraus_from_apply_channel (val , atol )
199213 if result is not None :
200214 return result
201215
0 commit comments