diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index aee220dd..464a71f4 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -1,6 +1,7 @@ -from typing import Type, Optional, Union, Dict +from typing import Hashable, MutableMapping, Type, Optional, Union, Dict from itertools import zip_longest from contextlib import suppress +import weakref import dsnparse import toml @@ -8,7 +9,6 @@ from typing_extensions import Self from ..abcs.mixins import AbstractMixin -from ..utils import WeakCache from .base import Database, ThreadedDatabase from .postgresql import PostgreSQL from .mysql import MySQL @@ -94,11 +94,12 @@ def match_path(self, dsn): class Connect: """Provides methods for connecting to a supported database using a URL or connection dict.""" + conn_cache: MutableMapping[Hashable, Database] def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): self.database_by_scheme = database_by_scheme self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} - self.conn_cache = WeakCache() + self.conn_cache = weakref.WeakValueDictionary() def for_databases(self, *dbs) -> Self: database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} @@ -263,9 +264,10 @@ def __call__( >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) """ + cache_key = self.__make_cache_key(db_conf) if shared: with suppress(KeyError): - conn = self.conn_cache.get(db_conf) + conn = self.conn_cache[cache_key] if not conn.is_closed: return conn @@ -277,5 +279,10 @@ def __call__( raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") if shared: - self.conn_cache.add(db_conf, conn) + self.conn_cache[cache_key] = conn return conn + + def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: + if isinstance(db_conf, dict): + return tuple(db_conf.items()) + return db_conf diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py index d356d18b..d1f596db 100644 --- a/data_diff/sqeleton/utils.py +++ b/data_diff/sqeleton/utils.py @@ -7,12 +7,10 @@ Any, Sequence, Dict, - Hashable, TypeVar, List, ) from abc import abstractmethod -from weakref import ref import math import string import re @@ -24,30 +22,6 @@ # -- Common -- -class WeakCache: - def __init__(self): - self._cache = {} - - def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable: - if isinstance(k, dict): - return tuple(k.items()) - return k - - def add(self, key: Union[dict, Hashable], value: Any): - key = self._hashable_key(key) - self._cache[key] = ref(value) - - def get(self, key: Union[dict, Hashable]) -> Any: - key = self._hashable_key(key) - - value = self._cache[key]() - if value is None: - del self._cache[key] - raise KeyError(f"Key {key} not found, or no longer a valid reference") - - return value - - def join_iter(joiner: Any, iterable: Iterable) -> Iterable: it = iter(iterable) try: diff --git a/tests/sqeleton/test_utils.py b/tests/sqeleton/test_utils.py index 25ec9c39..973121a2 100644 --- a/tests/sqeleton/test_utils.py +++ b/tests/sqeleton/test_utils.py @@ -1,6 +1,6 @@ import unittest -from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human, WeakCache +from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human class TestUtils(unittest.TestCase): @@ -81,24 +81,3 @@ def test_number_to_human(self): assert number_to_human(-1000) == "-1k" assert number_to_human(-1000000) == "-1m" assert number_to_human(-1000000000) == "-1b" - - def test_weak_cache(self): - # Create cache - cache = WeakCache() - - # Test adding and retrieving basic value - o = {1, 2} - cache.add("key", o) - assert cache.get("key") is o - - # Test adding and retrieving dict value - cache.add({"key": "value"}, o) - assert cache.get({"key": "value"}) is o - - # Test deleting value when reference is lost - del o - try: - cache.get({"key": "value"}) - assert False, "KeyError should have been raised" - except KeyError: - pass