10
10
from datetime import datetime
11
11
from decimal import Decimal
12
12
from functools import wraps
13
- from inspect import isclass
13
+ from inspect import getfullargspec , isclass
14
14
from typing import (
15
15
Any ,
16
16
Callable ,
17
17
ClassVar ,
18
18
Dict ,
19
+ Generic ,
20
+ Iterable ,
19
21
List ,
20
22
Optional ,
23
+ Set ,
21
24
Type ,
22
25
TypeVar ,
23
26
Union ,
27
+ cast ,
24
28
get_type_hints ,
25
29
)
26
30
31
+ import cbor2
32
+
33
+ from pycardano .logging import logger
34
+
35
+ # Remove the semantic decoder for 258 (CBOR tag for set) as we care about the order of elements
36
+ try :
37
+ cbor2 ._decoder .semantic_decoders .pop (258 )
38
+ except Exception as e :
39
+ logger .warning ("Failed to remove semantic decoder for CBOR tag 258" , e )
40
+ pass
41
+
27
42
from cbor2 import CBOREncoder , CBORSimpleValue , CBORTag , dumps , loads , undefined
28
43
from frozendict import frozendict
29
44
from frozenlist import FrozenList
44
59
"RawCBOR" ,
45
60
"list_hook" ,
46
61
"limit_primitive_type" ,
62
+ "OrderedSet" ,
63
+ "NonEmptyOrderedSet" ,
47
64
]
48
65
66
+ T = TypeVar ("T" )
67
+
49
68
50
69
def _identity (x ):
51
70
return x
@@ -314,10 +333,12 @@ def validate(self):
314
333
def _check_recursive (value , type_hint ):
315
334
if type_hint is Any :
316
335
return True
336
+
337
+ if isinstance (value , CBORSerializable ):
338
+ value .validate ()
339
+
317
340
origin = getattr (type_hint , "__origin__" , None )
318
341
if origin is None :
319
- if isinstance (value , CBORSerializable ):
320
- value .validate ()
321
342
return isinstance (value , type_hint )
322
343
elif origin is ClassVar :
323
344
return _check_recursive (value , type_hint .__args__ [0 ])
@@ -329,7 +350,7 @@ def _check_recursive(value, type_hint):
329
350
_check_recursive (k , key_type ) and _check_recursive (v , value_type )
330
351
for k , v in value .items ()
331
352
)
332
- elif origin in (list , set , tuple , frozenset ):
353
+ elif origin in (list , set , tuple , frozenset , OrderedSet ):
333
354
if value is None :
334
355
return True
335
356
args = type_hint .__args__
@@ -364,12 +385,15 @@ def to_validated_primitive(self) -> Primitive:
364
385
return self .to_primitive ()
365
386
366
387
@classmethod
367
- def from_primitive (cls : Type [CBORBase ], value : Any ) -> CBORBase :
388
+ def from_primitive (
389
+ cls : Type [CBORBase ], value : Any , type_args : Optional [tuple ] = None
390
+ ) -> CBORBase :
368
391
"""Turn a CBOR primitive to its original class type.
369
392
370
393
Args:
371
394
cls (CBORBase): The original class type.
372
395
value (:const:`Primitive`): A CBOR primitive.
396
+ type_args (Optional[tuple]): Type arguments for the class.
373
397
374
398
Returns:
375
399
CBORBase: A CBOR serializable object.
@@ -519,10 +543,26 @@ def _restore_typed_primitive(
519
543
Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable.
520
544
"""
521
545
546
+ is_cbor_serializable = False
547
+ try :
548
+ is_cbor_serializable = issubclass (t , CBORSerializable )
549
+ except TypeError :
550
+ # Handle the case when t is a generic alias
551
+ origin = typing .get_origin (t )
552
+ if origin is not None :
553
+ try :
554
+ is_cbor_serializable = issubclass (origin , CBORSerializable )
555
+ except TypeError :
556
+ pass
557
+
522
558
if t is Any or (t in PRIMITIVE_TYPES and isinstance (v , t )):
523
559
return v
524
- elif isclass (t ) and issubclass (t , CBORSerializable ):
525
- return t .from_primitive (v )
560
+ elif is_cbor_serializable :
561
+ if "type_args" in getfullargspec (t .from_primitive ).args :
562
+ args = typing .get_args (t )
563
+ return t .from_primitive (v , type_args = args )
564
+ else :
565
+ return t .from_primitive (v )
526
566
elif hasattr (t , "__origin__" ) and (t .__origin__ is list ):
527
567
t_args = t .__args__
528
568
if len (t_args ) != 1 :
@@ -941,3 +981,82 @@ def list_hook(
941
981
CBORSerializables.
942
982
"""
943
983
return lambda vals : [cls .from_primitive (v ) for v in vals ]
984
+
985
+
986
+ class OrderedSet (list , Generic [T ], CBORSerializable ):
987
+ def __init__ (self , iterable : Optional [List [T ]] = None , use_tag : bool = True ):
988
+ super ().__init__ ()
989
+ self ._set : Set [str ] = set ()
990
+ self ._use_tag = use_tag
991
+ if iterable :
992
+ self .extend (iterable )
993
+
994
+ def append (self , item : T ) -> None :
995
+ item_key = str (item )
996
+ if item_key not in self ._set :
997
+ super ().append (item )
998
+ self ._set .add (item_key )
999
+
1000
+ def extend (self , items : Iterable [T ]) -> None :
1001
+ for item in items :
1002
+ self .append (item )
1003
+
1004
+ def __contains__ (self , item : object ) -> bool :
1005
+ return str (item ) in self ._set
1006
+
1007
+ def __eq__ (self , other : object ) -> bool :
1008
+ if not isinstance (other , OrderedSet ):
1009
+ if isinstance (other , list ):
1010
+ return list (self ) == other
1011
+ return False
1012
+ return list (self ) == list (other )
1013
+
1014
+ def __repr__ (self ) -> str :
1015
+ return f"{ self .__class__ .__name__ } ({ list (self )} )"
1016
+
1017
+ def to_shallow_primitive (self ) -> Union [CBORTag , List [T ]]:
1018
+ if self ._use_tag :
1019
+ return CBORTag (258 , list (self ))
1020
+ return list (self )
1021
+
1022
+ @classmethod
1023
+ def from_primitive (
1024
+ cls : Type [OrderedSet [T ]], value : Primitive , type_args : Optional [tuple ] = None
1025
+ ) -> OrderedSet [T ]:
1026
+ assert (
1027
+ type_args is None or len (type_args ) == 1
1028
+ ), "OrderedSet should have exactly one type argument"
1029
+ # Retrieve the type arguments from the class
1030
+ type_arg = type_args [0 ] if type_args else None
1031
+
1032
+ if isinstance (value , CBORTag ) and value .tag == 258 :
1033
+ if isclass (type_arg ) and issubclass (type_arg , CBORSerializable ):
1034
+ value .value = [type_arg .from_primitive (v ) for v in value .value ]
1035
+ return cls (value .value , use_tag = True )
1036
+
1037
+ if isinstance (value , (list , tuple , set )):
1038
+ if isclass (type_arg ) and issubclass (type_arg , CBORSerializable ):
1039
+ value = [type_arg .from_primitive (v ) for v in value ]
1040
+ return cls (list (value ), use_tag = False )
1041
+
1042
+ raise ValueError (f"Cannot deserialize { value } to { cls } " )
1043
+
1044
+
1045
+ class NonEmptyOrderedSet (OrderedSet [T ]):
1046
+ def __init__ (self , iterable : Optional [List [T ]] = None , use_tag : bool = True ):
1047
+ super ().__init__ (iterable , use_tag )
1048
+
1049
+ def validate (self ):
1050
+ if not self :
1051
+ raise ValueError ("NonEmptyOrderedSet cannot be empty" )
1052
+
1053
+ @classmethod
1054
+ def from_primitive (
1055
+ cls : Type [NonEmptyOrderedSet [T ]],
1056
+ value : Primitive ,
1057
+ type_args : Optional [tuple ] = None ,
1058
+ ) -> NonEmptyOrderedSet [T ]:
1059
+ result = cast (NonEmptyOrderedSet [T ], super ().from_primitive (value , type_args ))
1060
+ if not result :
1061
+ raise ValueError ("NonEmptyOrderedSet cannot be empty" )
1062
+ return result
0 commit comments