1515"""Linear combination represented as mapping of things to coefficients."""
1616
1717from typing import (
18+ AbstractSet ,
1819 Any ,
1920 Callable ,
2021 Dict ,
2829 Optional ,
2930 overload ,
3031 Tuple ,
32+ TYPE_CHECKING ,
3133 TypeVar ,
3234 Union ,
3335 ValuesView ,
3436)
3537from typing_extensions import Self
3638
3739import numpy as np
40+ import sympy
41+ from cirq import protocols
42+
43+ if TYPE_CHECKING :
44+ import cirq
3845
3946Scalar = Union [complex , np .number ]
4047TVector = TypeVar ('TVector' )
4148
4249TDefault = TypeVar ('TDefault' )
4350
4451
45- def _format_coefficient (format_spec : str , coefficient : Scalar ) -> str :
52+ class _SympyPrinter (sympy .printing .str .StrPrinter ):
53+ def __init__ (self , format_spec : str ):
54+ super ().__init__ ()
55+ self ._format_spec = format_spec
56+
57+ def _print (self , expr , ** kwargs ):
58+ if expr .is_complex :
59+ coefficient = complex (expr )
60+ s = _format_coefficient (self ._format_spec , coefficient )
61+ return s [1 :- 1 ] if s .startswith ('(' ) else s
62+ return super ()._print (expr , ** kwargs )
63+
64+
65+ def _format_coefficient (format_spec : str , coefficient : 'cirq.TParamValComplex' ) -> str :
66+ if isinstance (coefficient , sympy .Basic ):
67+ printer = _SympyPrinter (format_spec )
68+ return printer .doprint (coefficient )
4669 coefficient = complex (coefficient )
4770 real_str = f'{ coefficient .real :{format_spec }} '
4871 imag_str = f'{ coefficient .imag :{format_spec }} '
@@ -59,7 +82,7 @@ def _format_coefficient(format_spec: str, coefficient: Scalar) -> str:
5982 return f'({ real_str } +{ imag_str } j)'
6083
6184
62- def _format_term (format_spec : str , vector : TVector , coefficient : Scalar ) -> str :
85+ def _format_term (format_spec : str , vector : TVector , coefficient : 'cirq.TParamValComplex' ) -> str :
6386 coefficient_str = _format_coefficient (format_spec , coefficient )
6487 if not coefficient_str :
6588 return coefficient_str
@@ -69,7 +92,7 @@ def _format_term(format_spec: str, vector: TVector, coefficient: Scalar) -> str:
6992 return '+' + result
7093
7194
72- def _format_terms (terms : Iterable [Tuple [TVector , Scalar ]], format_spec : str ):
95+ def _format_terms (terms : Iterable [Tuple [TVector , 'cirq.TParamValComplex' ]], format_spec : str ):
7396 formatted_terms = [_format_term (format_spec , vector , coeff ) for vector , coeff in terms ]
7497 s = '' .join (formatted_terms )
7598 if not s :
@@ -79,7 +102,7 @@ def _format_terms(terms: Iterable[Tuple[TVector, Scalar]], format_spec: str):
79102 return s
80103
81104
82- class LinearDict (Generic [TVector ], MutableMapping [TVector , Scalar ]):
105+ class LinearDict (Generic [TVector ], MutableMapping [TVector , 'cirq.TParamValComplex' ]):
83106 """Represents linear combination of things.
84107
85108 LinearDict implements the basic linear algebraic operations of vector
@@ -96,7 +119,7 @@ class LinearDict(Generic[TVector], MutableMapping[TVector, Scalar]):
96119
97120 def __init__ (
98121 self ,
99- terms : Optional [Mapping [TVector , Scalar ]] = None ,
122+ terms : Optional [Mapping [TVector , 'cirq.TParamValComplex' ]] = None ,
100123 validator : Optional [Callable [[TVector ], bool ]] = None ,
101124 ) -> None :
102125 """Initializes linear combination from a collection of terms.
@@ -112,21 +135,30 @@ def __init__(
112135 """
113136 self ._has_validator = validator is not None
114137 self ._is_valid = validator or (lambda x : True )
115- self ._terms : Dict [TVector , Scalar ] = {}
138+ self ._terms : Dict [TVector , 'cirq.TParamValComplex' ] = {}
116139 if terms is not None :
117140 self .update (terms )
118141
119142 @classmethod
120143 def fromkeys (cls , vectors , coefficient = 0 ):
121- return LinearDict (dict .fromkeys (vectors , complex (coefficient )))
144+ return LinearDict (
145+ dict .fromkeys (
146+ vectors ,
147+ coefficient if isinstance (coefficient , sympy .Basic ) else complex (coefficient ),
148+ )
149+ )
122150
123151 def _check_vector_valid (self , vector : TVector ) -> None :
124152 if not self ._is_valid (vector ):
125153 raise ValueError (f'{ vector } is not compatible with linear combination { self } ' )
126154
127155 def clean (self , * , atol : float = 1e-9 ) -> Self :
128156 """Remove terms with coefficients of absolute value atol or less."""
129- negligible = [v for v , c in self ._terms .items () if abs (complex (c )) <= atol ]
157+ negligible = [
158+ v
159+ for v , c in self ._terms .items ()
160+ if not isinstance (c , sympy .Basic ) and abs (complex (c )) <= atol
161+ ]
130162 for v in negligible :
131163 del self ._terms [v ]
132164 return self
@@ -139,40 +171,50 @@ def keys(self) -> KeysView[TVector]:
139171 snapshot = self .copy ().clean (atol = 0 )
140172 return snapshot ._terms .keys ()
141173
142- def values (self ) -> ValuesView [Scalar ]:
174+ def values (self ) -> ValuesView ['cirq.TParamValComplex' ]:
143175 snapshot = self .copy ().clean (atol = 0 )
144176 return snapshot ._terms .values ()
145177
146- def items (self ) -> ItemsView [TVector , Scalar ]:
178+ def items (self ) -> ItemsView [TVector , 'cirq.TParamValComplex' ]:
147179 snapshot = self .copy ().clean (atol = 0 )
148180 return snapshot ._terms .items ()
149181
150182 # pylint: disable=function-redefined
151183 @overload
152- def update (self , other : Mapping [TVector , Scalar ], ** kwargs : Scalar ) -> None :
184+ def update (
185+ self , other : Mapping [TVector , 'cirq.TParamValComplex' ], ** kwargs : 'cirq.TParamValComplex'
186+ ) -> None :
153187 pass
154188
155189 @overload
156- def update (self , other : Iterable [Tuple [TVector , Scalar ]], ** kwargs : Scalar ) -> None :
190+ def update (
191+ self ,
192+ other : Iterable [Tuple [TVector , 'cirq.TParamValComplex' ]],
193+ ** kwargs : 'cirq.TParamValComplex' ,
194+ ) -> None :
157195 pass
158196
159197 @overload
160- def update (self , * args : Any , ** kwargs : Scalar ) -> None :
198+ def update (self , * args : Any , ** kwargs : 'cirq.TParamValComplex' ) -> None :
161199 pass
162200
163201 def update (self , * args , ** kwargs ):
164202 terms = dict ()
165203 terms .update (* args , ** kwargs )
166204 for vector , coefficient in terms .items ():
205+ if isinstance (coefficient , sympy .Basic ):
206+ coefficient = sympy .simplify (coefficient )
207+ if coefficient .is_complex :
208+ coefficient = complex (coefficient )
167209 self [vector ] = coefficient
168210 self .clean (atol = 0 )
169211
170212 @overload
171- def get (self , vector : TVector ) -> Scalar :
213+ def get (self , vector : TVector ) -> 'cirq.TParamValComplex' :
172214 pass
173215
174216 @overload
175- def get (self , vector : TVector , default : TDefault ) -> Union [Scalar , TDefault ]:
217+ def get (self , vector : TVector , default : TDefault ) -> Union ['cirq.TParamValComplex' , TDefault ]:
176218 pass
177219
178220 def get (self , vector , default = 0 ):
@@ -185,10 +227,10 @@ def get(self, vector, default=0):
185227 def __contains__ (self , vector : Any ) -> bool :
186228 return vector in self ._terms and self ._terms [vector ] != 0
187229
188- def __getitem__ (self , vector : TVector ) -> Scalar :
230+ def __getitem__ (self , vector : TVector ) -> 'cirq.TParamValComplex' :
189231 return self ._terms .get (vector , 0 )
190232
191- def __setitem__ (self , vector : TVector , coefficient : Scalar ) -> None :
233+ def __setitem__ (self , vector : TVector , coefficient : 'cirq.TParamValComplex' ) -> None :
192234 self ._check_vector_valid (vector )
193235 if coefficient != 0 :
194236 self ._terms [vector ] = coefficient
@@ -236,21 +278,21 @@ def __neg__(self) -> Self:
236278 factory = type (self )
237279 return factory ({v : - c for v , c in self .items ()})
238280
239- def __imul__ (self , a : Scalar ) -> Self :
281+ def __imul__ (self , a : 'cirq.TParamValComplex' ) -> Self :
240282 for vector in self :
241283 self ._terms [vector ] *= a
242284 self .clean (atol = 0 )
243285 return self
244286
245- def __mul__ (self , a : Scalar ) -> Self :
287+ def __mul__ (self , a : 'cirq.TParamValComplex' ) -> Self :
246288 result = self .copy ()
247289 result *= a
248- return result
290+ return result . copy ()
249291
250- def __rmul__ (self , a : Scalar ) -> Self : # type: ignore
292+ def __rmul__ (self , a : 'cirq.TParamValComplex' ) -> Self :
251293 return self .__mul__ (a )
252294
253- def __truediv__ (self , a : Scalar ) -> Self :
295+ def __truediv__ (self , a : 'cirq.TParamValComplex' ) -> Self :
254296 return self .__mul__ (1 / a )
255297
256298 def __bool__ (self ) -> bool :
@@ -320,3 +362,19 @@ def _json_dict_(self) -> Dict[Any, Any]:
320362 @classmethod
321363 def _from_json_dict_ (cls , keys , values , ** kwargs ):
322364 return cls (terms = dict (zip (keys , values )))
365+
366+ def _is_parameterized_ (self ) -> bool :
367+ return any (protocols .is_parameterized (v ) for v in self ._terms .values ())
368+
369+ def _parameter_names_ (self ) -> AbstractSet [str ]:
370+ return set (name for v in self ._terms .values () for name in protocols .parameter_names (v ))
371+
372+ def _resolve_parameters_ (self , resolver : 'cirq.ParamResolver' , recursive : bool ) -> 'LinearDict' :
373+ result = self .copy ()
374+ result .update (
375+ {
376+ k : protocols .resolve_parameters (v , resolver , recursive )
377+ for k , v in self ._terms .items ()
378+ }
379+ )
380+ return result
0 commit comments