Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Convert to attrs, remove runtype #720

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions data_diff/abcs/compiler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from abc import ABC

import attrs


@attrs.define
class AbstractCompiler(ABC):
pass


@attrs.define(eq=False)
class Compilable(ABC):
pass
55 changes: 39 additions & 16 deletions data_diff/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, Union
from datetime import datetime

from runtype import dataclass
import attrs

from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown

Expand All @@ -13,55 +13,67 @@
DbTime = datetime


@dataclass
@attrs.define
class ColType:
supported = True
@property
def supported(self) -> bool:
return True


@dataclass
@attrs.define
class PrecisionType(ColType):
precision: int
rounds: Union[bool, Unknown] = Unknown



Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

@attrs.define
class Boolean(ColType):
precision = 0


@attrs.define
class TemporalType(PrecisionType):
pass


@attrs.define
class Timestamp(TemporalType):
pass


@attrs.define
class TimestampTZ(TemporalType):
pass


@attrs.define
class Datetime(TemporalType):
pass


@attrs.define
class Date(TemporalType):
pass


@dataclass
@attrs.define
class NumericType(ColType):
# 'precision' signifies how many fractional digits (after the dot) we want to compare
precision: int


@attrs.define
class FractionalType(NumericType):
pass


@attrs.define
class Float(FractionalType):
python_type = float


@attrs.define
class IKey(ABC):
"Interface for ColType, for using a column as a key in a table."

Expand All @@ -74,6 +86,7 @@ def make_value(self, value):
return self.python_type(value)


@attrs.define
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
@property
def python_type(self) -> type:
Expand All @@ -82,27 +95,32 @@ def python_type(self) -> type:
return decimal.Decimal


@dataclass
@attrs.define
class StringType(ColType):
python_type = str


@attrs.define
class ColType_UUID(ColType, IKey):
python_type = ArithUUID


@attrs.define
class ColType_Alphanum(ColType, IKey):
python_type = ArithAlphanumeric


@attrs.define
class Native_UUID(ColType_UUID):
pass


@attrs.define
class String_UUID(ColType_UUID, StringType):
pass


@attrs.define
class String_Alphanum(ColType_Alphanum, StringType):
@staticmethod
def test_value(value: str) -> bool:
Expand All @@ -116,11 +134,12 @@ def make_value(self, value):
return self.python_type(value)


@attrs.define
class String_VaryingAlphanum(String_Alphanum):
pass


@dataclass
@attrs.define
class String_FixedAlphanum(String_Alphanum):
length: int

Expand All @@ -130,18 +149,20 @@ def make_value(self, value):
return self.python_type(value, max_len=self.length)


@dataclass
@attrs.define
class Text(StringType):
supported = False
@property
def supported(self) -> bool:
return False


# In the majority of DBMSes, it is called JSON/JSONB. Only in Snowflake is it called OBJECT.
@dataclass
@attrs.define
class JSON(ColType):
pass


@dataclass
@attrs.define
class Array(ColType):
item_type: ColType

Expand All @@ -151,22 +172,24 @@ class Array(ColType):
# For example, in BigQuery:
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
@dataclass
@attrs.define
class Struct(ColType):
pass


@dataclass
@attrs.define
class Integer(NumericType, IKey):
precision: int = 0
python_type: type = int

def __post_init__(self):
def __attrs_post_init__(self):
assert self.precision == 0


@dataclass
@attrs.define
class UnknownColType(ColType):
text: str

supported = False
@property
def supported(self) -> bool:
return False
10 changes: 10 additions & 0 deletions data_diff/abcs/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from abc import ABC, abstractmethod

import attrs

from data_diff.abcs.database_types import (
Array,
TemporalType,
Expand All @@ -13,10 +16,12 @@
from data_diff.abcs.compiler import Compilable


@attrs.define
class AbstractMixin(ABC):
"A mixin for a database dialect"


@attrs.define
class AbstractMixin_NormalizeValue(AbstractMixin):
@abstractmethod
def to_comparable(self, value: str, coltype: ColType) -> str:
Expand Down Expand Up @@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.to_string(value)


@attrs.define
class AbstractMixin_MD5(AbstractMixin):
"""Methods for calculating an MD5 hash as an integer."""

Expand All @@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str:
"Provide SQL for computing md5 and returning an int"


@attrs.define
class AbstractMixin_Schema(AbstractMixin):
"""Methods for querying the database schema

Expand All @@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
"""


@attrs.define
class AbstractMixin_RandomSample(AbstractMixin):
@abstractmethod
def random_sample_n(self, tbl: str, size: int) -> str:
Expand All @@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str:
# """


@attrs.define
class AbstractMixin_TimeTravel(AbstractMixin):
@abstractmethod
def time_travel(
Expand All @@ -173,6 +182,7 @@ def time_travel(
"""


@attrs.define
class AbstractMixin_OptimizerHints(AbstractMixin):
@abstractmethod
def optimizer_hints(self, optimizer_hints: str) -> str:
Expand Down
6 changes: 3 additions & 3 deletions data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import base64
import dataclasses
import enum
import time
from typing import Any, Dict, List, Optional, Type, Tuple

import attrs
import pydantic
import requests
from typing_extensions import Self
Expand Down Expand Up @@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel):
result: Optional[TCloudDataSourceTestResult]


@dataclasses.dataclass
@attrs.define
class DatafoldAPI:
api_key: str
host: str = "https://app.datafold.com"
timeout: int = 30

def __post_init__(self):
def __attrs_post_init__(self):
self.host = self.host.rstrip("/")
self.headers = {
"Authorization": f"Key {self.api_key}",
Expand Down
9 changes: 7 additions & 2 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from itertools import zip_longest
from contextlib import suppress
import weakref

import attrs
import dsnparse
import toml

from runtype import dataclass
from typing_extensions import Self

from data_diff.databases.base import Database, ThreadedDatabase
Expand All @@ -25,7 +26,7 @@
from data_diff.databases.mssql import MsSQL


@dataclass
@attrs.define
class MatchUriPath:
database_cls: Type[Database]

Expand Down Expand Up @@ -92,8 +93,11 @@ def match_path(self, dsn):
}


@attrs.define(init=False)
class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""
database_by_scheme: Dict[str, Database]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change
database_by_scheme: Dict[str, Database]
database_by_scheme: Dict[str, Database]

match_uri_path: Dict[str, MatchUriPath]
conn_cache: MutableMapping[Hashable, Database]

def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
Expand Down Expand Up @@ -283,6 +287,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
return db_conf


@attrs.define(init=False)
class Connect_SetUTC(Connect):
"""Provides methods for connecting to a supported database using a URL or connection dict.

Expand Down
Loading