14
14
See the License for the specific language governing permissions and
15
15
limitations under the License.
16
16
"""
17
- from collections import defaultdict
18
17
from inspect import isclass
19
- from typing import Callable , ClassVar , Dict , List , Mapping , MutableMapping , Optional , Text , Tuple , Type , Union
18
+ from typing import Callable , ClassVar , Dict , List , Mapping , Optional , Text , Tuple , Type , Union
20
19
21
20
from pydantic import BaseModel , PrivateAttr
22
21
import structlog # type: ignore
25
24
from .enum import DiffSyncModelFlags , DiffSyncFlags , DiffSyncStatus
26
25
from .exceptions import DiffClassMismatch , ObjectAlreadyExists , ObjectStoreWrongType , ObjectNotFound
27
26
from .helpers import DiffSyncDiffer , DiffSyncSyncer
27
+ from .store import BaseStore
28
+ from .store .local import LocalStore
28
29
29
30
30
31
class DiffSyncModel (BaseModel ):
@@ -408,19 +409,17 @@ class DiffSync:
408
409
top_level : ClassVar [List [str ]] = []
409
410
"""List of top-level modelnames to begin from when diffing or synchronizing."""
410
411
411
- _data : MutableMapping [str , MutableMapping [str , DiffSyncModel ]]
412
- """Defaultdict storing model instances.
413
-
414
- `self._data[modelname][unique_id] == model_instance`
415
- """
416
-
417
- def __init__ (self , name = None ):
412
+ def __init__ (self , name = None , internal_storage_engine = LocalStore ):
418
413
"""Generic initialization function.
419
414
420
415
Subclasses should be careful to call super().__init__() if they override this method.
421
416
"""
422
- self ._data = defaultdict (dict )
423
- self ._log = structlog .get_logger ().new (diffsync = self )
417
+
418
+ if isinstance (internal_storage_engine , BaseStore ):
419
+ self .store = internal_storage_engine
420
+ self .store .diffsync = self
421
+ else :
422
+ self .store = internal_storage_engine (diffsync = self )
424
423
425
424
# If the type is not defined, use the name of the class as the default value
426
425
if self .type is None :
@@ -458,8 +457,8 @@ def __repr__(self):
458
457
return f"<{ str (self )} >"
459
458
460
459
def __len__ (self ):
461
- """Total number of elements stored in self._data ."""
462
- return sum ( len ( entries ) for entries in self ._data . values () )
460
+ """Total number of elements stored."""
461
+ return self .store . count ( )
463
462
464
463
def load (self ):
465
464
"""Load all desired data from whatever backend data source into this instance."""
@@ -468,10 +467,10 @@ def load(self):
468
467
def dict (self , exclude_defaults : bool = True , ** kwargs ) -> Mapping :
469
468
"""Represent the DiffSync contents as a dict, as if it were a Pydantic model."""
470
469
data : Dict [str , Dict [str , Dict ]] = {}
471
- for modelname in self ._data :
470
+ for modelname in self .store . get_all_model_names () :
472
471
data [modelname ] = {}
473
- for unique_id , model in self ._data [ modelname ]. items ( ):
474
- data [modelname ][ unique_id ] = model .dict (exclude_defaults = exclude_defaults , ** kwargs )
472
+ for obj in self .store . get_all ( model = modelname ):
473
+ data [obj . get_type ()][ obj . get_unique_id () ] = obj .dict (exclude_defaults = exclude_defaults , ** kwargs )
475
474
return data
476
475
477
476
def str (self , indent : int = 0 ) -> str :
@@ -615,9 +614,18 @@ def diff_to(
615
614
# Object Storage Management
616
615
# ------------------------------------------------------------------------------
617
616
617
+ def get_all_model_names (self ):
618
+ """Get all model names.
619
+
620
+ Returns:
621
+ List[str]: List of model names
622
+ """
623
+ return self .store .get_all_model_names ()
624
+
618
625
def get (
619
626
self , obj : Union [Text , DiffSyncModel , Type [DiffSyncModel ]], identifier : Union [Text , Mapping ]
620
627
) -> DiffSyncModel :
628
+
621
629
"""Get one object from the data store based on its unique id.
622
630
623
631
Args:
@@ -628,29 +636,7 @@ def get(
628
636
ValueError: if obj is a str and identifier is a dict (can't convert dict into a uid str without a model class)
629
637
ObjectNotFound: if the requested object is not present
630
638
"""
631
- if isinstance (obj , str ):
632
- modelname = obj
633
- if not hasattr (self , obj ):
634
- object_class = None
635
- else :
636
- object_class = getattr (self , obj )
637
- else :
638
- object_class = obj
639
- modelname = obj .get_type ()
640
-
641
- if isinstance (identifier , str ):
642
- uid = identifier
643
- elif object_class :
644
- uid = object_class .create_unique_id (** identifier )
645
- else :
646
- raise ValueError (
647
- f"Invalid args: ({ obj } , { identifier } ): "
648
- f"either { obj } should be a class/instance or { identifier } should be a str"
649
- )
650
-
651
- if uid not in self ._data [modelname ]:
652
- raise ObjectNotFound (f"{ modelname } { uid } not present in { self .name } " )
653
- return self ._data [modelname ][uid ]
639
+ return self .store .get (model = obj , identifier = identifier )
654
640
655
641
def get_all (self , obj : Union [Text , DiffSyncModel , Type [DiffSyncModel ]]) -> List [DiffSyncModel ]:
656
642
"""Get all objects of a given type.
@@ -661,12 +647,7 @@ def get_all(self, obj: Union[Text, DiffSyncModel, Type[DiffSyncModel]]) -> List[
661
647
Returns:
662
648
List[DiffSyncModel]: List of Object
663
649
"""
664
- if isinstance (obj , str ):
665
- modelname = obj
666
- else :
667
- modelname = obj .get_type ()
668
-
669
- return list (self ._data [modelname ].values ())
650
+ return self .store .get_all (model = obj )
670
651
671
652
def get_by_uids (
672
653
self , uids : List [Text ], obj : Union [Text , DiffSyncModel , Type [DiffSyncModel ]]
@@ -680,17 +661,7 @@ def get_by_uids(
680
661
Raises:
681
662
ObjectNotFound: if any of the requested UIDs are not found in the store
682
663
"""
683
- if isinstance (obj , str ):
684
- modelname = obj
685
- else :
686
- modelname = obj .get_type ()
687
-
688
- results = []
689
- for uid in uids :
690
- if uid not in self ._data [modelname ]:
691
- raise ObjectNotFound (f"{ modelname } { uid } not present in { self .name } " )
692
- results .append (self ._data [modelname ][uid ])
693
- return results
664
+ return self .store .get_by_uids (uids = uids , model = obj )
694
665
695
666
def add (self , obj : DiffSyncModel ):
696
667
"""Add a DiffSyncModel object to the store.
@@ -701,20 +672,18 @@ def add(self, obj: DiffSyncModel):
701
672
Raises:
702
673
ObjectAlreadyExists: if a different object with the same uid is already present.
703
674
"""
704
- modelname = obj .get_type ()
705
- uid = obj .get_unique_id ()
675
+ return self .store .add (obj = obj )
706
676
707
- existing_obj = self ._data [modelname ].get (uid )
708
- if existing_obj :
709
- if existing_obj is not obj :
710
- raise ObjectAlreadyExists (f"Object { uid } already present" , obj )
711
- # Return so we don't have to change anything on the existing object and underlying data
712
- return
677
+ def update (self , obj : DiffSyncModel ):
678
+ """Update a DiffSyncModel object to the store.
713
679
714
- if not obj . diffsync :
715
- obj . diffsync = self
680
+ Args :
681
+ obj (DiffSyncModel): Object to store
716
682
717
- self ._data [modelname ][uid ] = obj
683
+ Raises:
684
+ ObjectAlreadyExists: if a different object with the same uid is already present.
685
+ """
686
+ return self .store .update (obj = obj )
718
687
719
688
def remove (self , obj : DiffSyncModel , remove_children : bool = False ):
720
689
"""Remove a DiffSyncModel object from the store.
@@ -726,26 +695,7 @@ def remove(self, obj: DiffSyncModel, remove_children: bool = False):
726
695
Raises:
727
696
ObjectNotFound: if the object is not present
728
697
"""
729
- modelname = obj .get_type ()
730
- uid = obj .get_unique_id ()
731
-
732
- if uid not in self ._data [modelname ]:
733
- raise ObjectNotFound (f"{ modelname } { uid } not present in { self .name } " )
734
-
735
- if obj .diffsync is self :
736
- obj .diffsync = None
737
-
738
- del self ._data [modelname ][uid ]
739
-
740
- if remove_children :
741
- for child_type , child_fieldname in obj .get_children_mapping ().items ():
742
- for child_id in getattr (obj , child_fieldname ):
743
- try :
744
- child_obj = self .get (child_type , child_id )
745
- self .remove (child_obj , remove_children = remove_children )
746
- except ObjectNotFound :
747
- # Since this is "cleanup" code, log an error and continue, instead of letting the exception raise
748
- self ._log .error (f"Unable to remove child { child_id } of { modelname } { uid } - not found!" )
698
+ return self .store .remove (obj = obj , remove_children = remove_children )
749
699
750
700
def get_or_instantiate (
751
701
self , model : Type [DiffSyncModel ], ids : Dict , attrs : Dict = None
@@ -760,18 +710,7 @@ def get_or_instantiate(
760
710
Returns:
761
711
Tuple[DiffSyncModel, bool]: Provides the existing or new object and whether it was created or not.
762
712
"""
763
- created = False
764
- try :
765
- obj = self .get (model , ids )
766
- except ObjectNotFound :
767
- if not attrs :
768
- attrs = {}
769
- obj = model (** ids , ** attrs )
770
- # Add the object to diffsync adapter
771
- self .add (obj )
772
- created = True
773
-
774
- return obj , created
713
+ return self .store .get_or_instantiate (model = model , ids = ids , attrs = attrs )
775
714
776
715
def update_or_instantiate (self , model : Type [DiffSyncModel ], ids : Dict , attrs : Dict ) -> Tuple [DiffSyncModel , bool ]:
777
716
"""Attempt to update an existing object with provided ids/attrs or instantiate it with provided identifiers and attrs.
@@ -784,21 +723,18 @@ def update_or_instantiate(self, model: Type[DiffSyncModel], ids: Dict, attrs: Di
784
723
Returns:
785
724
Tuple[DiffSyncModel, bool]: Provides the existing or new object and whether it was created or not.
786
725
"""
787
- created = False
788
- try :
789
- obj = self .get (model , ids )
790
- except ObjectNotFound :
791
- obj = model (** ids , ** attrs )
792
- # Add the object to diffsync adapter
793
- self .add (obj )
794
- created = True
726
+ return self .store .update_or_instantiate (model = model , ids = ids , attrs = attrs )
795
727
796
- # Update existing obj with attrs
797
- for attr , value in attrs .items ():
798
- if getattr (obj , attr ) != value :
799
- setattr (obj , attr , value )
728
+ def count (self , model : Union [Text , "DiffSyncModel" , Type ["DiffSyncModel" ], None ] = None ):
729
+ """Count how many objects of one model type exist in the backend store.
800
730
801
- return obj , created
731
+ Args:
732
+ model (DiffSyncModel): The DiffSyncModel to check the number of elements. If not provided, default to all.
733
+
734
+ Returns:
735
+ Int: Number of elements of the model type
736
+ """
737
+ return self .store .count (model = model )
802
738
803
739
804
740
# DiffSyncModel references DiffSync and DiffSync references DiffSyncModel. Break the typing loop:
0 commit comments