Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5ff7ecc

Browse files
author
Sergey Vasilyev
committed
Annotate missing fields
1 parent 8e633de commit 5ff7ecc

17 files changed

+77
-29
lines changed

data_diff/databases/_connect.py

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def match_path(self, dsn):
9595
class Connect:
9696
"""Provides methods for connecting to a supported database using a URL or connection dict."""
9797

98+
database_by_scheme: Dict[str, Database]
99+
match_uri_path: Dict[str, MatchUriPath]
98100
conn_cache: MutableMapping[Hashable, Database]
99101

100102
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):

data_diff/databases/base.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import sys
77
import logging
8-
from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
8+
from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
99
from functools import partial, wraps
1010
from concurrent.futures import ThreadPoolExecutor
1111
import threading
@@ -179,6 +179,9 @@ class ThreadLocalInterpreter:
179179
Useful for cursor-sensitive operations, such as creating a temporary table.
180180
"""
181181

182+
compiler: Compiler
183+
gen: Generator
184+
182185
def __init__(self, compiler: Compiler, gen: Generator):
183186
super().__init__()
184187
self.gen = gen
@@ -238,9 +241,9 @@ def optimizer_hints(self, hints: str) -> str:
238241

239242

240243
class BaseDialect(abc.ABC):
241-
SUPPORTS_PRIMARY_KEY = False
242-
SUPPORTS_INDEXES = False
243-
TYPE_CLASSES: Dict[str, type] = {}
244+
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
245+
SUPPORTS_INDEXES: ClassVar[bool] = False
246+
TYPE_CLASSES: ClassVar[Dict[str, type]] = {}
244247

245248
PLACEHOLDER_TABLE = None # Used for Oracle
246249

@@ -835,14 +838,13 @@ class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase):
835838
Instanciated using :meth:`~data_diff.connect`
836839
"""
837840

838-
default_schema: str = None
839-
SUPPORTS_ALPHANUMS = True
840-
SUPPORTS_UNIQUE_CONSTAINT = False
841-
842-
CONNECT_URI_KWPARAMS = []
841+
SUPPORTS_ALPHANUMS: ClassVar[bool] = True
842+
SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
843+
CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []
843844

844-
_interactive = False
845-
is_closed = False
845+
default_schema: Optional[str] = None
846+
_interactive: bool = False
847+
is_closed: bool = False
846848

847849
@property
848850
def name(self):
@@ -1109,6 +1111,10 @@ class ThreadedDatabase(Database):
11091111
Used for database connectors that do not support sharing their connection between different threads.
11101112
"""
11111113

1114+
_init_error: Optional[Exception]
1115+
_queue: ThreadPoolExecutor
1116+
thread_local: threading.local
1117+
11121118
def __init__(self, thread_count=1):
11131119
super().__init__()
11141120
self._init_error = None

data_diff/databases/bigquery.py

+4
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ class BigQuery(Database):
223223
CONNECT_URI_PARAMS = ["dataset"]
224224
dialect = Dialect()
225225

226+
project: str
227+
dataset: str
228+
_client: Any
229+
226230
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
227231
super().__init__()
228232
credentials = bigquery_credentials

data_diff/databases/clickhouse.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Type
1+
from typing import Any, Dict, Optional, Type
22

33
from data_diff.databases.base import (
44
MD5_HEXDIGITS,
@@ -167,6 +167,8 @@ class Clickhouse(ThreadedDatabase):
167167
CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>"
168168
CONNECT_URI_PARAMS = ["database?"]
169169

170+
_args: Dict[str, Any]
171+
170172
def __init__(self, *, thread_count: int, **kw):
171173
super().__init__(thread_count=thread_count)
172174

data_diff/databases/databricks.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Dict, Sequence
2+
from typing import Any, Dict, Sequence
33
import logging
44

55
from data_diff.abcs.database_types import (
@@ -103,13 +103,16 @@ class Databricks(ThreadedDatabase):
103103
CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>"
104104
CONNECT_URI_PARAMS = ["catalog", "schema"]
105105

106+
catalog: str
107+
_args: Dict[str, Any]
108+
106109
def __init__(self, *, thread_count, **kw):
107110
super().__init__(thread_count=thread_count)
108111
logging.getLogger("databricks.sql").setLevel(logging.WARNING)
109112

110113
self._args = kw
111114
self.default_schema = kw.get("schema", "default")
112-
self.catalog = self._args.get("catalog", "hive_metastore")
115+
self.catalog = kw.get("catalog", "hive_metastore")
113116

114117
def create_connection(self):
115118
databricks = import_databricks()

data_diff/databases/duckdb.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Any, Dict, Union
22

33
from data_diff.utils import match_regexps
44
from data_diff.abcs.database_types import (
@@ -134,14 +134,17 @@ def current_timestamp(self) -> str:
134134
class DuckDB(Database):
135135
dialect = Dialect()
136136
SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it
137-
default_schema = "main"
138137
CONNECT_URI_HELP = "duckdb://<dbname>@<filepath>"
139138
CONNECT_URI_PARAMS = ["database", "dbpath"]
140139

140+
_args: Dict[str, Any]
141+
_conn: Any
142+
141143
def __init__(self, **kw):
142144
super().__init__()
143145
self._args = kw
144146
self._conn = self.create_connection()
147+
self.default_schema = "main"
145148

146149
@property
147150
def is_autocommit(self) -> bool:

data_diff/databases/mssql.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Any, Dict, Optional
22
from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
33
from data_diff.databases.base import (
44
CHECKSUM_HEXDIGITS,
@@ -160,10 +160,12 @@ def constant_values(self, rows) -> str:
160160

161161
class MsSQL(ThreadedDatabase):
162162
dialect = Dialect()
163-
#
164163
CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
165164
CONNECT_URI_PARAMS = ["database", "schema"]
166165

166+
default_database: str
167+
_args: Dict[str, Any]
168+
167169
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
168170
args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
169171
self._args = {k: v for k, v in args.items() if v is not None}

data_diff/databases/mysql.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Dict
2+
13
from data_diff.abcs.database_types import (
24
Datetime,
35
Timestamp,
@@ -137,6 +139,8 @@ class MySQL(ThreadedDatabase):
137139
CONNECT_URI_HELP = "mysql://<user>:<password>@<host>/<database>"
138140
CONNECT_URI_PARAMS = ["database?"]
139141

142+
_args: Dict[str, Any]
143+
140144
def __init__(self, *, thread_count, **kw):
141145
super().__init__(thread_count=thread_count)
142146
self._args = kw

data_diff/databases/oracle.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional
1+
from typing import Any, Dict, List, Optional
22

33
from data_diff.utils import match_regexps
44
from data_diff.abcs.database_types import (
@@ -181,6 +181,8 @@ class Oracle(ThreadedDatabase):
181181
CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>"
182182
CONNECT_URI_PARAMS = ["database?"]
183183

184+
kwargs: Dict[str, Any]
185+
184186
def __init__(self, *, host, database, thread_count, **kw):
185187
super().__init__(thread_count=thread_count)
186188
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

data_diff/databases/postgresql.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List
1+
from typing import Any, ClassVar, Dict, List, Type
22
from data_diff.abcs.database_types import (
3+
ColType,
34
DbPath,
45
JSON,
56
Timestamp,
@@ -68,7 +69,7 @@ class PostgresqlDialect(
6869
SUPPORTS_PRIMARY_KEY = True
6970
SUPPORTS_INDEXES = True
7071

71-
TYPE_CLASSES = {
72+
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
7273
# Timestamps
7374
"timestamp with time zone": TimestampTZ,
7475
"timestamp without time zone": Timestamp,
@@ -125,11 +126,12 @@ class PostgreSQL(ThreadedDatabase):
125126
CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>"
126127
CONNECT_URI_PARAMS = ["database?"]
127128

128-
default_schema = "public"
129+
_args: Dict[str, Any]
129130

130131
def __init__(self, *, thread_count, **kw):
131132
super().__init__(thread_count=thread_count)
132133
self._args = kw
134+
self.default_schema = "public"
133135

134136
def create_connection(self):
135137
if not self._args:

data_diff/databases/presto.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import partial
22
import re
3+
from typing import Any
34

45
from data_diff.utils import match_regexps
56

@@ -158,10 +159,11 @@ class Presto(Database):
158159
CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>"
159160
CONNECT_URI_PARAMS = ["catalog", "schema"]
160161

161-
default_schema = "public"
162+
_conn: Any
162163

163164
def __init__(self, **kw):
164165
super().__init__()
166+
self.default_schema = "public"
165167
prestodb = import_presto()
166168

167169
if kw.get("schema"):

data_diff/databases/redshift.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List, Dict
1+
from typing import ClassVar, List, Dict, Type
22
from data_diff.abcs.database_types import (
3+
ColType,
34
Float,
45
JSON,
56
TemporalType,
@@ -53,7 +54,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
5354

5455
class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
5556
name = "Redshift"
56-
TYPE_CLASSES = {
57+
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
5758
**PostgresqlDialect.TYPE_CLASSES,
5859
"double": Float,
5960
"real": Float,

data_diff/databases/snowflake.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union, List
1+
from typing import Any, Union, List
22
import logging
33

44
from data_diff.abcs.database_types import (
@@ -154,6 +154,8 @@ class Snowflake(Database):
154154
CONNECT_URI_PARAMS = ["database", "schema"]
155155
CONNECT_URI_KWPARAMS = ["warehouse"]
156156

157+
_conn: Any
158+
157159
def __init__(self, *, schema: str, **kw):
158160
super().__init__()
159161
snowflake, serialization, default_backend = import_snowflake()

data_diff/databases/trino.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
24
from data_diff.abcs.database_types import TemporalType, ColType_UUID
35
from data_diff.databases import presto
@@ -39,6 +41,8 @@ class Trino(presto.Presto):
3941
CONNECT_URI_HELP = "trino://<user>@<host>/<catalog>/<schema>"
4042
CONNECT_URI_PARAMS = ["catalog", "schema"]
4143

44+
_conn: Any
45+
4246
def __init__(self, **kw):
4347
super().__init__()
4448
trino = import_trino()

data_diff/databases/vertica.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import Any, Dict, List
22

33
from data_diff.utils import match_regexps
44
from data_diff.databases.base import (
@@ -156,12 +156,13 @@ class Vertica(ThreadedDatabase):
156156
CONNECT_URI_HELP = "vertica://<user>:<password>@<host>/<database>"
157157
CONNECT_URI_PARAMS = ["database?"]
158158

159-
default_schema = "public"
159+
_args: Dict[str, Any]
160160

161161
def __init__(self, *, thread_count, **kw):
162162
super().__init__(thread_count=thread_count)
163163
self._args = kw
164164
self._args["AUTOCOMMIT"] = False
165+
self.default_schema = "public"
165166

166167
def create_connection(self):
167168
vertica = import_vertica()

data_diff/thread_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ class ThreadedYielder(Iterable):
4545
Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first)
4646
"""
4747

48+
_pool: ThreadPoolExecutor
49+
_futures: deque
50+
_yield: deque
51+
_exception: Optional[None]
52+
4853
def __init__(self, max_workers: Optional[int] = None):
4954
super().__init__()
5055
self._pool = PriorityThreadPoolExecutor(max_workers)

data_diff/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import string
66
from abc import abstractmethod
7-
from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union
7+
from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Optional, Sequence, TypeVar, Union
88
from urllib.parse import urlparse
99
import operator
1010
import threading
@@ -175,6 +175,9 @@ def alphanums_to_numbers(s1: str, s2: str):
175175

176176

177177
class ArithAlphanumeric(ArithString):
178+
_str: str
179+
_max_len: Optional[int]
180+
178181
def __init__(self, s: str, max_len=None):
179182
super().__init__()
180183

0 commit comments

Comments
 (0)