Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Convert to attrs, remove runtype #723

Merged
merged 7 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions data_diff/abcs/compiler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from abc import ABC

import attrs


@attrs.define(frozen=False)
class AbstractCompiler(ABC):
    """Abstract base for compilers; declares no fields or methods of its own."""


@attrs.define(frozen=False, eq=False)
class Compilable(ABC):
    """Abstract base for compilable objects; declares no fields or methods.

    ``eq=False`` tells attrs not to generate ``__eq__`` for this class, so
    it keeps the equality and hashing it inherits.
    """
54 changes: 38 additions & 16 deletions data_diff/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, Union
from datetime import datetime

from runtype import dataclass
import attrs

from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown

Expand All @@ -13,55 +13,66 @@
DbTime = datetime


@dataclass
@attrs.define(frozen=True)
class ColType:
supported = True
@property
def supported(self) -> bool:
return True


@dataclass
@attrs.define(frozen=True)
class PrecisionType(ColType):
precision: int
rounds: Union[bool, Unknown] = Unknown


@attrs.define(frozen=True)
class Boolean(ColType):
    """Boolean column type with a class-level ``precision`` of 0."""

    # Unannotated, so attrs leaves it as a plain class attribute rather than
    # turning it into an instance field.
    precision = 0


@attrs.define(frozen=True)
class TemporalType(PrecisionType):
    """Base for temporal column types; adds no fields to ``PrecisionType``."""


@attrs.define(frozen=True)
class Timestamp(TemporalType):
    """Timestamp column type; a ``TemporalType`` with no extra members."""


@attrs.define(frozen=True)
class TimestampTZ(TemporalType):
    """TimestampTZ column type; a ``TemporalType`` with no extra members."""


@attrs.define(frozen=True)
class Datetime(TemporalType):
    """Datetime column type; a ``TemporalType`` with no extra members."""


@attrs.define(frozen=True)
class Date(TemporalType):
    """Date column type; a ``TemporalType`` with no extra members."""


@dataclass
@attrs.define(frozen=True)
class NumericType(ColType):
# 'precision' signifies how many fractional digits (after the dot) we want to compare
precision: int


@attrs.define(frozen=True)
class FractionalType(NumericType):
    """Base for fractional numeric column types; adds no fields to ``NumericType``."""


@attrs.define(frozen=True)
class Float(FractionalType):
    """Fractional column type whose ``python_type`` is the builtin ``float``."""

    # Unannotated, so attrs leaves it as a plain class attribute rather than
    # turning it into an instance field.
    python_type = float


@attrs.define(frozen=True)
class IKey(ABC):
"Interface for ColType, for using a column as a key in table."

Expand All @@ -74,6 +85,7 @@ def make_value(self, value):
return self.python_type(value)


@attrs.define(frozen=True)
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
@property
def python_type(self) -> type:
Expand All @@ -82,27 +94,32 @@ def python_type(self) -> type:
return decimal.Decimal


@dataclass
@attrs.define(frozen=True)
class StringType(ColType):
python_type = str


@attrs.define(frozen=True)
class ColType_UUID(ColType, IKey):
    """Column type usable as a key (``IKey``) whose ``python_type`` is ``ArithUUID``."""

    # Unannotated, so attrs leaves it as a plain class attribute rather than
    # turning it into an instance field.
    python_type = ArithUUID


@attrs.define(frozen=True)
class ColType_Alphanum(ColType, IKey):
    """Column type usable as a key (``IKey``) whose ``python_type`` is ``ArithAlphanumeric``."""

    # Unannotated, so attrs leaves it as a plain class attribute rather than
    # turning it into an instance field.
    python_type = ArithAlphanumeric


@attrs.define(frozen=True)
class Native_UUID(ColType_UUID):
    """A ``ColType_UUID`` subclass that adds no members of its own."""


@attrs.define(frozen=True)
class String_UUID(ColType_UUID, StringType):
    """Combines ``ColType_UUID`` and ``StringType``; adds no members of its own."""


@attrs.define(frozen=True)
class String_Alphanum(ColType_Alphanum, StringType):
@staticmethod
def test_value(value: str) -> bool:
Expand All @@ -116,11 +133,12 @@ def make_value(self, value):
return self.python_type(value)


@attrs.define(frozen=True)
class String_VaryingAlphanum(String_Alphanum):
    """A ``String_Alphanum`` subclass that adds no members of its own."""


@dataclass
@attrs.define(frozen=True)
class String_FixedAlphanum(String_Alphanum):
length: int

Expand All @@ -130,18 +148,20 @@ def make_value(self, value):
return self.python_type(value, max_len=self.length)


@dataclass
@attrs.define(frozen=True)
class Text(StringType):
supported = False
@property
def supported(self) -> bool:
return False


# In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT.
@dataclass
@attrs.define(frozen=True)
class JSON(ColType):
pass


@dataclass
@attrs.define(frozen=True)
class Array(ColType):
item_type: ColType

Expand All @@ -151,22 +171,24 @@ class Array(ColType):
# For example, in BigQuery:
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
@dataclass
@attrs.define(frozen=True)
class Struct(ColType):
pass


@dataclass
@attrs.define(frozen=True)
class Integer(NumericType, IKey):
precision: int = 0
python_type: type = int

def __post_init__(self):
def __attrs_post_init__(self):
assert self.precision == 0


@dataclass
@attrs.define(frozen=True)
class UnknownColType(ColType):
text: str

supported = False
@property
def supported(self) -> bool:
return False
10 changes: 10 additions & 0 deletions data_diff/abcs/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from abc import ABC, abstractmethod

import attrs

from data_diff.abcs.database_types import (
Array,
TemporalType,
Expand All @@ -13,10 +16,12 @@
from data_diff.abcs.compiler import Compilable


@attrs.define(frozen=False)
class AbstractMixin(ABC):
    """A mixin for a database dialect."""


@attrs.define(frozen=False)
class AbstractMixin_NormalizeValue(AbstractMixin):
@abstractmethod
def to_comparable(self, value: str, coltype: ColType) -> str:
Expand Down Expand Up @@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.to_string(value)


@attrs.define(frozen=False)
class AbstractMixin_MD5(AbstractMixin):
"""Methods for calculating an MD5 hash as an integer."""

Expand All @@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str:
"Provide SQL for computing md5 and returning an int"


@attrs.define(frozen=False)
class AbstractMixin_Schema(AbstractMixin):
"""Methods for querying the database schema

Expand All @@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
"""


@attrs.define(frozen=False)
class AbstractMixin_RandomSample(AbstractMixin):
@abstractmethod
def random_sample_n(self, tbl: str, size: int) -> str:
Expand All @@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str:
# """


@attrs.define(frozen=False)
class AbstractMixin_TimeTravel(AbstractMixin):
@abstractmethod
def time_travel(
Expand All @@ -173,6 +182,7 @@ def time_travel(
"""


@attrs.define(frozen=False)
class AbstractMixin_OptimizerHints(AbstractMixin):
@abstractmethod
def optimizer_hints(self, optimizer_hints: str) -> str:
Expand Down
6 changes: 3 additions & 3 deletions data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import base64
import dataclasses
import enum
import time
from typing import Any, Dict, List, Optional, Type, Tuple

import attrs
import pydantic
import requests
from typing_extensions import Self
Expand Down Expand Up @@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel):
result: Optional[TCloudDataSourceTestResult]


@dataclasses.dataclass
@attrs.define(frozen=True)
class DatafoldAPI:
api_key: str
host: str = "https://app.datafold.com"
timeout: int = 30

def __post_init__(self):
def __attrs_post_init__(self):
self.host = self.host.rstrip("/")
self.headers = {
"Authorization": f"Key {self.api_key}",
Expand Down
10 changes: 8 additions & 2 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from itertools import zip_longest
from contextlib import suppress
import weakref

import attrs
import dsnparse
import toml

from runtype import dataclass
from typing_extensions import Self

from data_diff.databases.base import Database, ThreadedDatabase
Expand All @@ -25,7 +26,7 @@
from data_diff.databases.mssql import MsSQL


@dataclass
@attrs.define(frozen=True)
class MatchUriPath:
database_cls: Type[Database]

Expand Down Expand Up @@ -92,12 +93,16 @@ def match_path(self, dsn):
}


@attrs.define(frozen=False, init=False)
class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""

database_by_scheme: Dict[str, Database]
match_uri_path: Dict[str, MatchUriPath]
conn_cache: MutableMapping[Hashable, Database]

def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
super().__init__()
self.database_by_scheme = database_by_scheme
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
self.conn_cache = weakref.WeakValueDictionary()
Expand Down Expand Up @@ -284,6 +289,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
return db_conf


@attrs.define(frozen=False, init=False)
class Connect_SetUTC(Connect):
"""Provides methods for connecting to a supported database using a URL or connection dict.

Expand Down
Loading