This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 279
/
Copy pathdatabricks.py
234 lines (187 loc) · 8.01 KB
/
databricks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import math
from typing import Any, ClassVar, Dict, Sequence, Type
import logging
import attrs
from data_diff.abcs.database_types import (
Integer,
Float,
Decimal,
Timestamp,
Text,
TemporalType,
NumericType,
DbPath,
ColType,
UnknownColType,
Boolean,
)
from data_diff.databases.base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
CHECKSUM_OFFSET,
BaseDialect,
ThreadedDatabase,
import_helper,
parse_table_name,
)
@import_helper(text="You can install it using 'pip install databricks-sql-connector'")
def import_databricks():
    # Lazy import so the connector is only required when a Databricks
    # connection is actually created; import_helper turns an ImportError
    # into a user-friendly install hint.
    import databricks.sql

    # Importing the submodule binds the top-level package name; callers use
    # both `databricks.sql.connect` and `databricks.sql.exc`.
    return databricks
@attrs.define(frozen=False)
class Dialect(BaseDialect):
    """SQL dialect for Databricks (Spark SQL flavor).

    Maps Databricks column type names onto data-diff's abstract column
    types and renders the SQL snippets (quoting, casting, checksum and
    normalization expressions) used by the diffing queries.
    """

    name = "Databricks"
    ROUNDS_ON_PREC_LOSS = True
    TYPE_CLASSES = {
        # Numbers
        "INT": Integer,
        "SMALLINT": Integer,
        "TINYINT": Integer,
        "BIGINT": Integer,
        "FLOAT": Float,
        "DOUBLE": Float,
        "DECIMAL": Decimal,
        # Timestamps
        "TIMESTAMP": Timestamp,
        "TIMESTAMP_NTZ": Timestamp,
        # Text
        "STRING": Text,
        "VARCHAR": Text,
        # Boolean
        "BOOLEAN": Boolean,
    }

    def type_repr(self, t) -> str:
        # Databricks calls its text type STRING; everything else is handled
        # by the base dialect.
        if t is str:
            return "STRING"
        return super().type_repr(t)

    def quote(self, s: str):
        # Databricks identifiers are quoted with backticks.
        return "`%s`" % (s,)

    def to_string(self, s: str) -> str:
        return "cast(%s as string)" % (s,)

    def _convert_db_precision_to_digits(self, p: int) -> int:
        # Subtracting 2 due to weird precision issues observed on Databricks.
        digits = super()._convert_db_precision_to_digits(p) - 2
        return digits if digits > 0 else 0

    def set_timezone_to_utc(self) -> str:
        return "SET TIME ZONE 'UTC'"

    def parse_table_name(self, name: str) -> DbPath:
        # Drop any missing (None) components from the parsed path.
        return tuple(part for part in parse_table_name(name) if part is not None)

    def md5_as_int(self, s: str) -> str:
        # Use only the low CHECKSUM_HEXDIGITS hex digits of the md5 so the
        # result fits in a decimal(38, 0), then shift by CHECKSUM_OFFSET.
        start = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
        return f"cast(conv(substr(md5({s}), {start}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        return f"md5({s})"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Databricks timestamp contains no more than 6 digits in precision"""
        if not coltype.rounds:
            # Truncate: keep `precision` fractional digits, zero-pad to 6.
            fraction = "S" * coltype.precision + "0" * (6 - coltype.precision)
            return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{fraction}')"
        # cast to timestamp due to unix_micros() requiring timestamp
        micros = f"cast(round(unix_micros(cast({value} as timestamp)) / 1000000, {coltype.precision}) * 1000000 as bigint)"
        return f"date_format(timestamp_micros({micros}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"

    def normalize_number(self, value: str, coltype: NumericType) -> str:
        # Fix the scale, then strip the thousands separators that
        # format_number() inserts.
        casted = f"cast({value} as decimal(38, {coltype.precision}))"
        if coltype.precision > 0:
            casted = f"format_number({casted}, {coltype.precision})"
        return f"replace({self.to_string(casted)}, ',', '')"

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        # Render booleans as "0"/"1" strings.
        return self.to_string(f"cast ({value} as int)")
@attrs.define(frozen=False, init=False, kw_only=True)
class Databricks(ThreadedDatabase):
    """Databricks database adapter, using the `databricks-sql-connector` driver.

    Connections are created per worker thread (see ThreadedDatabase).
    """

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>"
    CONNECT_URI_PARAMS = ["catalog", "schema"]

    catalog: str
    _args: Dict[str, Any]

    def __init__(self, *, thread_count, **kw):
        super().__init__(thread_count=thread_count)
        # The connector's logger is chatty at INFO level.
        logging.getLogger("databricks.sql").setLevel(logging.WARNING)
        self._args = kw
        self.default_schema = kw.get("schema", "default")
        self.catalog = kw.get("catalog", "hive_metastore")

    def create_connection(self):
        """Open a new connection; raises ConnectionError on driver failure."""
        databricks = import_databricks()
        try:
            return databricks.sql.connect(
                server_hostname=self._args["server_hostname"],
                http_path=self._args["http_path"],
                access_token=self._args["access_token"],
                catalog=self.catalog,
            )
        except databricks.sql.exc.Error as e:
            raise ConnectionError(*e.args) from e

    def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
        """Return the raw schema for `path` as {column_name: row tuple}.

        Tries the information_schema query first; if that fails or returns
        nothing, falls back to the driver's cursor.columns() metadata call.

        Raises:
            RuntimeError: if the table does not exist or has no columns.
        """
        table_schema = {}
        try:
            table_schema = super().query_table_schema(path)
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
            logging.warning("Failed to get schema from information_schema, falling back to legacy approach.")

        if table_schema:
            return table_schema

        # This legacy approach can cause bugs. e.g. VARCHAR(255) -> VARCHAR(255)
        # and not the expected VARCHAR
        # I don't think we'll fall back to this approach, but if so, see above
        catalog, schema, table = self._normalize_table_path(path)
        # Open the connection only for the fallback path; previously it was
        # opened unconditionally and leaked when information_schema succeeded.
        conn = self.create_connection()
        try:
            with conn.cursor() as cursor:
                cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
                rows = cursor.fetchall()
        finally:
            conn.close()

        if not rows:
            raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

        table_schema = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
        # Duplicate column names would silently collapse in the dict.
        assert len(table_schema) == len(rows)
        return table_schema

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
        database, schema, name = self._normalize_table_path(path)
        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, database)
        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def _process_table_schema(
        self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
    ):
        """Filter raw_schema to `filter_columns` and massage each row into the
        (name, type, datetime_prec, numeric_prec, numeric_scale) shape that
        the dialect's parse_type expects."""
        accept = {i.lower() for i in filter_columns}
        rows = [row for name, row in raw_schema.items() if name.lower() in accept]

        resulted_rows = []
        for row in rows:
            # DECIMAL(p, s) reports its precision inline; normalize the type key.
            row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1]
            type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType)

            if issubclass(type_cls, Integer):
                row = (row[0], row_type, None, None, 0)
            elif issubclass(type_cls, Float):
                # Convert binary (bit) precision to decimal digits.
                numeric_precision = math.ceil(row[2] / math.log(2, 10))
                row = (row[0], row_type, None, numeric_precision, None)
            elif issubclass(type_cls, Decimal):
                # Parse "DECIMAL(p,s)" -> p, s (skip the 8-char "DECIMAL(" prefix).
                items = row[1][8:].rstrip(")").split(",")
                numeric_precision, numeric_scale = int(items[0]), int(items[1])
                row = (row[0], row_type, None, numeric_precision, numeric_scale)
            elif issubclass(type_cls, Timestamp):
                row = (row[0], row_type, row[2], None, None)
            else:
                row = (row[0], row_type, None, None, None)

            resulted_rows.append(row)

        col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows}

        self._refine_coltypes(path, col_dict, where)
        return col_dict

    @property
    def is_autocommit(self) -> bool:
        return True

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand a 1-3 element path to a full (catalog, schema, table) triple."""
        if len(path) == 1:
            return self.catalog, self.default_schema, path[0]
        elif len(path) == 2:
            return self.catalog, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or catalog.schema.table"
        )