1313# limitations under the License.
1414"""A typed time delta that supports picosecond accuracy."""
1515
16- from typing import AbstractSet , Any , Dict , Optional , Tuple , TYPE_CHECKING , Union
16+ from typing import AbstractSet , Any , Dict , Optional , Tuple , TYPE_CHECKING , Union , List
1717import datetime
1818
1919import sympy
2020import numpy as np
2121
2222from cirq import protocols
23- from cirq ._compat import proper_repr
23+ from cirq ._compat import proper_repr , cached_method
2424from cirq ._doc import document
2525
2626if TYPE_CHECKING :
@@ -79,48 +79,53 @@ def __init__(
7979 >>> print(cirq.Duration(micros=1.5 * sympy.Symbol('t')))
8080 (1500.0*t) ns
8181 """
82+ self ._time_vals : List [_NUMERIC_INPUT_TYPE ] = [0 , 0 , 0 , 0 ]
83+ self ._multipliers = [1 , 1000 , 1000_000 , 1000_000_000 ]
8284 if value is not None and value != 0 :
8385 if isinstance (value , datetime .timedelta ):
8486 # timedelta has microsecond resolution.
85- micros + = int (value / datetime .timedelta (microseconds = 1 ))
87+ self . _time_vals [ 2 ] = int (value / datetime .timedelta (microseconds = 1 ))
8688 elif isinstance (value , Duration ):
87- picos + = value ._picos
89+ self . _time_vals = value ._time_vals
8890 else :
8991 raise TypeError (f'Not a `cirq.DURATION_LIKE`: { repr (value )} .' )
90-
91- val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
92- self ._picos : _NUMERIC_OUTPUT_TYPE = float (val ) if isinstance (val , np .number ) else val
92+ input_vals = [picos , nanos , micros , millis ]
93+ self ._time_vals = _add_time_vals (self ._time_vals , input_vals )
9394
9495 def _is_parameterized_ (self ) -> bool :
95- return protocols .is_parameterized (self ._picos )
96+ return protocols .is_parameterized (self ._time_vals )
9697
9798 def _parameter_names_ (self ) -> AbstractSet [str ]:
98- return protocols .parameter_names (self ._picos )
99+ return protocols .parameter_names (self ._time_vals )
99100
100101 def _resolve_parameters_ (self , resolver : 'cirq.ParamResolver' , recursive : bool ) -> 'Duration' :
101- return Duration (picos = protocols .resolve_parameters (self ._picos , resolver , recursive ))
102+ return _duration_from_time_vals (
103+ protocols .resolve_parameters (self ._time_vals , resolver , recursive )
104+ )
102105
106+ @cached_method
103107 def total_picos (self ) -> _NUMERIC_OUTPUT_TYPE :
104108 """Returns the number of picoseconds that the duration spans."""
105- return self ._picos
109+ val = sum (a * b for a , b in zip (self ._time_vals , self ._multipliers ))
110+ return float (val ) if isinstance (val , np .number ) else val
106111
107112 def total_nanos (self ) -> _NUMERIC_OUTPUT_TYPE :
108113 """Returns the number of nanoseconds that the duration spans."""
109- return self ._picos / 1000
114+ return self .total_picos () / 1000
110115
111116 def total_micros (self ) -> _NUMERIC_OUTPUT_TYPE :
112117 """Returns the number of microseconds that the duration spans."""
113- return self ._picos / 1000_000
118+ return self .total_picos () / 1000_000
114119
115120 def total_millis (self ) -> _NUMERIC_OUTPUT_TYPE :
116121 """Returns the number of milliseconds that the duration spans."""
117- return self ._picos / 1000_000_000
122+ return self .total_picos () / 1000_000_000
118123
119124 def __add__ (self , other ) -> 'Duration' :
120125 other = _attempt_duration_like_to_duration (other )
121126 if other is None :
122127 return NotImplemented
123- return Duration ( picos = self ._picos + other ._picos )
128+ return _duration_from_time_vals ( _add_time_vals ( self ._time_vals , other ._time_vals ) )
124129
125130 def __radd__ (self , other ) -> 'Duration' :
126131 return self .__add__ (other )
@@ -129,86 +134,94 @@ def __sub__(self, other) -> 'Duration':
129134 other = _attempt_duration_like_to_duration (other )
130135 if other is None :
131136 return NotImplemented
132- return Duration (picos = self ._picos - other ._picos )
137+ return _duration_from_time_vals (
138+ _add_time_vals (self ._time_vals , [- x for x in other ._time_vals ])
139+ )
133140
134141 def __rsub__ (self , other ) -> 'Duration' :
135142 other = _attempt_duration_like_to_duration (other )
136143 if other is None :
137144 return NotImplemented
138- return Duration (picos = other ._picos - self ._picos )
145+ return _duration_from_time_vals (
146+ _add_time_vals (other ._time_vals , [- x for x in self ._time_vals ])
147+ )
139148
140149 def __mul__ (self , other ) -> 'Duration' :
141150 if not isinstance (other , (int , float , sympy .Expr )):
142151 return NotImplemented
143- return Duration (picos = self ._picos * other )
152+ if other == 0 :
153+ return _duration_from_time_vals ([0 ] * 4 )
154+ return _duration_from_time_vals ([x * other for x in self ._time_vals ])
144155
145156 def __rmul__ (self , other ) -> 'Duration' :
146157 return self .__mul__ (other )
147158
148159 def __truediv__ (self , other ) -> Union ['Duration' , float ]:
149160 if isinstance (other , (int , float , sympy .Expr )):
150- return Duration (picos = self ._picos / other )
161+ new_time_vals = [x / other for x in self ._time_vals ]
162+ return _duration_from_time_vals (new_time_vals )
151163
152164 other_duration = _attempt_duration_like_to_duration (other )
153165 if other_duration is not None :
154- return self ._picos / other_duration ._picos
166+ return self .total_picos () / other_duration .total_picos ()
155167
156168 return NotImplemented
157169
158170 def __eq__ (self , other ):
159171 other = _attempt_duration_like_to_duration (other )
160172 if other is None :
161173 return NotImplemented
162- return self ._picos == other ._picos
174+ return self .total_picos () == other .total_picos ()
163175
164176 def __ne__ (self , other ):
165177 other = _attempt_duration_like_to_duration (other )
166178 if other is None :
167179 return NotImplemented
168- return self ._picos != other ._picos
180+ return self .total_picos () != other .total_picos ()
169181
170182 def __gt__ (self , other ):
171183 other = _attempt_duration_like_to_duration (other )
172184 if other is None :
173185 return NotImplemented
174- return self ._picos > other ._picos
186+ return self .total_picos () > other .total_picos ()
175187
176188 def __lt__ (self , other ):
177189 other = _attempt_duration_like_to_duration (other )
178190 if other is None :
179191 return NotImplemented
180- return self ._picos < other ._picos
192+ return self .total_picos () < other .total_picos ()
181193
182194 def __ge__ (self , other ):
183195 other = _attempt_duration_like_to_duration (other )
184196 if other is None :
185197 return NotImplemented
186- return self ._picos >= other ._picos
198+ return self .total_picos () >= other .total_picos ()
187199
188200 def __le__ (self , other ):
189201 other = _attempt_duration_like_to_duration (other )
190202 if other is None :
191203 return NotImplemented
192- return self ._picos <= other ._picos
204+ return self .total_picos () <= other .total_picos ()
193205
194206 def __bool__ (self ):
195- return bool (self ._picos )
207+ return bool (self .total_picos () )
196208
197209 def __hash__ (self ):
198- if isinstance (self ._picos , (int , float )) and self ._picos % 1000000 == 0 :
199- return hash (datetime .timedelta (microseconds = self ._picos / 1000000 ))
200- return hash ((Duration , self ._picos ))
210+ if isinstance (self .total_picos () , (int , float )) and self .total_picos () % 1000000 == 0 :
211+ return hash (datetime .timedelta (microseconds = self .total_picos () / 1000000 ))
212+ return hash ((Duration , self .total_picos () ))
201213
202214 def _decompose_into_amount_unit_suffix (self ) -> Tuple [int , str , str ]:
215+ picos = self .total_picos ()
203216 if (
204- isinstance (self . _picos , sympy .Mul )
205- and len (self . _picos .args ) == 2
206- and isinstance (self . _picos .args [0 ], (sympy .Integer , sympy .Float ))
217+ isinstance (picos , sympy .Mul )
218+ and len (picos .args ) == 2
219+ and isinstance (picos .args [0 ], (sympy .Integer , sympy .Float ))
207220 ):
208- scale = self . _picos .args [0 ]
209- rest = self . _picos .args [1 ]
221+ scale = picos .args [0 ]
222+ rest = picos .args [1 ]
210223 else :
211- scale = self . _picos
224+ scale = picos
212225 rest = 1
213226
214227 if scale % 1000_000_000 == 0 :
@@ -234,7 +247,7 @@ def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
234247 return amount * rest , unit , suffix
235248
236249 def __str__ (self ) -> str :
237- if self ._picos == 0 :
250+ if self .total_picos () == 0 :
238251 return 'Duration(0)'
239252 amount , _ , suffix = self ._decompose_into_amount_unit_suffix ()
240253 if not isinstance (amount , (int , float , sympy .Symbol )):
@@ -257,3 +270,21 @@ def _attempt_duration_like_to_duration(value: Any) -> Optional[Duration]:
257270 if isinstance (value , (int , float )) and value == 0 :
258271 return Duration ()
259272 return None
273+
274+
275+ def _add_time_vals (
276+ val1 : List [_NUMERIC_INPUT_TYPE ], val2 : List [_NUMERIC_INPUT_TYPE ]
277+ ) -> List [_NUMERIC_INPUT_TYPE ]:
278+ ret : List [_NUMERIC_INPUT_TYPE ] = []
279+ for i in range (4 ):
280+ if val1 [i ] and val2 [i ]:
281+ ret .append (val1 [i ] + val2 [i ])
282+ else :
283+ ret .append (val1 [i ] or val2 [i ])
284+ return ret
285+
286+
287+ def _duration_from_time_vals (time_vals : List [_NUMERIC_INPUT_TYPE ]):
288+ ret = Duration ()
289+ ret ._time_vals = time_vals
290+ return ret
0 commit comments