Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Simplify by replacing the self-made WeakCache with the builtin weakref.WeakValueDictionary #703

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions data_diff/sqeleton/databases/_connect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Type, Optional, Union, Dict
from typing import Hashable, MutableMapping, Type, Optional, Union, Dict
from itertools import zip_longest
from contextlib import suppress
import weakref
import dsnparse
import toml

from runtype import dataclass
from typing_extensions import Self

from ..abcs.mixins import AbstractMixin
from ..utils import WeakCache
from .base import Database, ThreadedDatabase
from .postgresql import PostgreSQL
from .mysql import MySQL
Expand Down Expand Up @@ -94,11 +94,12 @@ def match_path(self, dsn):

class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""
conn_cache: MutableMapping[Hashable, Database]

def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
self.database_by_scheme = database_by_scheme
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
self.conn_cache = WeakCache()
self.conn_cache = weakref.WeakValueDictionary()

def for_databases(self, *dbs) -> Self:
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
Expand Down Expand Up @@ -263,9 +264,10 @@ def __call__(
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
<data_diff.sqeleton.databases.mysql.MySQL object at ...>
"""
cache_key = self.__make_cache_key(db_conf)
if shared:
with suppress(KeyError):
conn = self.conn_cache.get(db_conf)
conn = self.conn_cache[cache_key]
if not conn.is_closed:
return conn

Expand All @@ -277,5 +279,10 @@ def __call__(
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")

if shared:
self.conn_cache.add(db_conf, conn)
self.conn_cache[cache_key] = conn
return conn

def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
    """Turn a connection configuration into a key usable in the connection cache.

    Args:
        db_conf: Either a URI string or a configuration dictionary.

    Returns:
        ``db_conf`` itself when it is not a dict. For a dict, a ``frozenset``
        of its ``(key, value)`` pairs, so that two equal configurations map
        to the same cache key regardless of the order their entries were
        written in.
    """
    if not isinstance(db_conf, dict):
        return db_conf

    pairs = db_conf.items()
    try:
        # Order-insensitive: {"a": 1, "b": 2} and {"b": 2, "a": 1} share a key.
        return frozenset(pairs)
    except TypeError:
        # Some value is unhashable (e.g. a nested list/dict). Building a
        # frozenset hashes eagerly, a tuple does not; fall back to the plain
        # tuple so this method itself never raises for such a configuration.
        return tuple(pairs)
26 changes: 0 additions & 26 deletions data_diff/sqeleton/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
Any,
Sequence,
Dict,
Hashable,
TypeVar,
List,
)
from abc import abstractmethod
from weakref import ref
import math
import string
import re
Expand All @@ -24,30 +22,6 @@
# -- Common --


class WeakCache:
    """A small cache that holds its values through weak references.

    An entry does not keep its value alive: once the value has been
    garbage-collected, looking the key up raises ``KeyError`` and the stale
    entry is dropped. Keys may be any hashable object, or a ``dict`` whose
    keys and values are hashable.
    """

    def __init__(self):
        # Maps normalized key -> weakref.ref pointing at the cached value.
        self._cache = {}

    def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
        """Return a hashable stand-in for ``k``.

        A dict is represented by the frozenset of its ``(key, value)`` pairs,
        so two equal dicts produce the same cache key even when their entries
        were inserted in a different order. Anything else is returned as is.
        """
        if isinstance(k, dict):
            return frozenset(k.items())
        return k

    def add(self, key: Union[dict, Hashable], value: Any):
        """Store a weak reference to ``value`` under ``key``.

        Any existing entry for the same key is replaced. ``value`` must
        support weak references.
        """
        self._cache[self._hashable_key(key)] = ref(value)

    def get(self, key: Union[dict, Hashable]) -> Any:
        """Return the live value stored under ``key``.

        Raises:
            KeyError: if the key was never added, or if its value is no
                longer alive (the dead entry is removed first).
        """
        key = self._hashable_key(key)

        # Calling the weak reference yields the object, or None once it is gone.
        value = self._cache[key]()
        if value is None:
            del self._cache[key]
            raise KeyError(f"Key {key} not found, or no longer a valid reference")

        return value


def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
it = iter(iterable)
try:
Expand Down
23 changes: 1 addition & 22 deletions tests/sqeleton/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human, WeakCache
from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human


class TestUtils(unittest.TestCase):
Expand Down Expand Up @@ -81,24 +81,3 @@ def test_number_to_human(self):
assert number_to_human(-1000) == "-1k"
assert number_to_human(-1000000) == "-1m"
assert number_to_human(-1000000000) == "-1b"

def test_weak_cache(self):
    """Exercise WeakCache with a string key, a dict key, and a dropped value."""
    weak_cache = WeakCache()
    payload = {1, 2}

    # A plain hashable key hands back the very same object.
    weak_cache.add("key", payload)
    assert weak_cache.get("key") is payload

    # A dict is accepted as a key; an equal dict built separately finds it.
    weak_cache.add({"key": "value"}, payload)
    assert weak_cache.get({"key": "value"}) is payload

    # Drop the only strong reference held by this test; the lookup is then
    # expected to fail with KeyError.
    del payload
    try:
        weak_cache.get({"key": "value"})
    except KeyError:
        pass
    else:
        assert False, "KeyError should have been raised"