4
4
import typing
5
5
import warnings
6
6
from collections import defaultdict
7
- from .utils import UNSPECIFIED , boolean_types , float_types , complex_types
8
- from .expr import Expr , make_constant , make_symbol , make_apply , known_expression_kinds
7
+ from .utils import UNSPECIFIED , boolean_types , float_types , complex_types , integer_types
8
+ from .expr import Expr , make_constant , make_symbol , make_apply , known_expression_kinds , make_list , make_item , make_len
9
9
from .typesystem import Type
10
10
11
11
@@ -194,11 +194,14 @@ def trace(self, func, *args):
194
194
if ":" in a :
195
195
name , annot = a .split (":" , 1 )
196
196
name = name .strip ()
197
+ # TODO: parse annot as it may contain typing alias
197
198
param = param .replace (name = name if name else param .name , annotation = annot .strip ())
198
199
else :
199
200
param = param .replace (name = a .strip ())
200
201
elif isinstance (a , type ):
201
202
param = param .replace (annotation = a .__name__ )
203
+ elif isinstance (a , types .GenericAlias ):
204
+ param = param .replace (annotation = a )
202
205
else :
203
206
raise NotImplementedError ((a , type (a )))
204
207
if param .annotation is inspect .Parameter .empty :
@@ -207,12 +210,28 @@ def trace(self, func, *args):
207
210
typ = param .annotation
208
211
if isinstance (typ , types .UnionType ):
209
212
typ = typing .get_args (typ )[0 ]
210
- assert isinstance (typ , (type , str )), typ
213
+ assert isinstance (typ , (type , str , types . GenericAlias )), ( type ( typ ), typ )
211
214
if default_typ is UNSPECIFIED :
212
215
default_typ = typ
213
- a = self .symbol (param .name , typ ).reference (ref_name = param .name )
216
+ if isinstance (typ , types .GenericAlias ):
217
+ if typ .__name__ == "list" :
218
+ a = self .list (
219
+ [
220
+ self .symbol (f"{ param .name } _{ k } _" , t ).reference (
221
+ ref_name = f"{ param .name } _{ k } _" ,
222
+ force = False , # reference to item will be defined only when it is used
223
+ )
224
+ for k , t in enumerate (typ .__args__ )
225
+ ]
226
+ )
227
+ else :
228
+ raise TypeError (f"annotation type must be type of a scalar or list, got { typ } " )
229
+ else :
230
+ a = self .symbol (param .name , typ )
231
+ a = a .reference (ref_name = param .name )
214
232
new_args .append (a )
215
233
args = tuple (new_args )
234
+
216
235
name = self .symbol (func .__name__ ).reference (ref_name = func .__name__ )
217
236
return make_apply (self , name , args , func (self , * args ))
218
237
@@ -233,6 +252,7 @@ def symbol(self, name, typ=UNSPECIFIED):
233
252
if typ is UNSPECIFIED :
234
253
like = self .default_like
235
254
if like is not None :
255
+ assert like .kind == "symbol" , like .kind
236
256
typ = like .operands [1 ]
237
257
else :
238
258
typ = "float"
@@ -243,15 +263,31 @@ def constant(self, value, like_expr=UNSPECIFIED):
243
263
if isinstance (value , boolean_types ):
244
264
like_expr = self .symbol ("_boolean_value" , "boolean" )
245
265
elif self ._default_constant_type is not None :
266
+ # Warning: when specified and default_constant_type is
267
+ # float, integer values will be interpreted as floats
246
268
like_expr = self .symbol ("_value" , self ._default_constant_type )
269
+ elif isinstance (value , integer_types ):
270
+ like_expr = self .symbol ("_integer_value" , type (value ))
247
271
elif isinstance (value , float_types ):
248
272
like_expr = self .symbol ("_float_value" , type (value ))
249
273
elif isinstance (value , complex_types ):
250
274
like_expr = self .symbol ("_complex_value" , type (value ))
251
275
else :
252
276
like_expr = self .default_like
277
+ elif isinstance (like_expr , (str , type , Type )):
278
+ typ = Type .fromobject (self , like_expr )
279
+ like_expr = self .symbol (f"_{ typ .kind } _value" , typ )
253
280
return make_constant (self , value , like_expr )
254
281
282
+ def list (self , items ):
283
+ return make_list (self , items )
284
+
285
+ def item (self , container , index ):
286
+ return make_item (self , container , index )
287
+
288
+ def len (self , container ):
289
+ return make_len (self , container )
290
+
255
291
def call (self , func , args ):
256
292
"""Apply callable to arguments and return its result.
257
293
0 commit comments