forked from datafold/data-diff
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompiler.py
87 lines (65 loc) · 2.72 KB
/
compiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import random
from dataclasses import field
from datetime import datetime
from typing import Any, Dict, Sequence, List
from runtype import dataclass
from typing_extensions import Self
from ..utils import ArithString
from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable
import contextvars
# Context variable that holds the `params` passed to `Compiler.compile()`.
# It is created without a default, so `cv_params.get()` raises LookupError
# until some `compile(..., params=...)` call has set it in the current context.
cv_params = contextvars.ContextVar("params")
class CompileError(Exception):
    """Exception type for errors encountered while compiling a query to SQL."""
class Root:
    """Marker base class for nodes that may stand alone as a root SQL statement.

    Nodes inheriting from Root can be used as root statements in SQL
    (e.g. SELECT yes, RANDOM() no).
    """
@dataclass
class Compiler(AbstractCompiler):
    """Compiles query elements and plain Python values into SQL text.

    The compiler is bound to a database and delegates dialect-specific
    rendering (quoting, timestamp literals) to ``database.dialect``.
    Derived compilers are created with ``self.replace(...)`` rather than by
    mutating the instance.
    """

    database: AbstractDatabase
    params: dict = field(default_factory=dict)
    in_select: bool = False  # Compilation runtime flag
    in_join: bool = False  # Compilation runtime flag

    _table_context: List = field(default_factory=list)  # List[ITable]
    # Maps a subquery name to its compiled SQL; rendered as a WITH clause
    # by compile() at the root level and then cleared.
    _subqueries: Dict[str, Any] = field(default_factory=dict)  # XXX not thread-safe
    # True only for the top-level compiler; _compile() hands nested
    # elements a copy with root=False.
    root: bool = True

    # The counter lives in a one-element list so it can be incremented in
    # place (presumably so copies made via replace() share it -- TODO confirm).
    _counter: List = field(default_factory=lambda: [0])

    @property
    def dialect(self) -> AbstractDialect:
        """Return the dialect of the bound database."""
        return self.database.dialect

    def compile(self, elem, params=None) -> str:
        """Compile ``elem`` into a SQL string.

        If ``params`` is truthy it is stored in the ``cv_params`` context
        variable before compiling.

        On the root compiler, a Compilable that is not a ``Root`` node is
        wrapped in ``Select(columns=[elem])`` first. Any subqueries
        registered during compilation are emitted as a leading ``WITH``
        clause, and the registry is cleared afterwards.
        """
        if params:
            cv_params.set(params)

        if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
            # Imported here rather than at module level
            # (presumably to avoid a circular import -- TODO confirm).
            from .ast_classes import Select

            elem = Select(columns=[elem])

        res = self._compile(elem)
        if self.root and self._subqueries:
            subq = ", ".join(f"\n  {k} AS ({v})" for k, v in self._subqueries.items())
            self._subqueries.clear()
            return f"WITH {subq}\n{res}"
        return res

    def _compile(self, elem) -> str:
        """Render a single element or Python value as SQL text.

        Raises:
            CompileError: if ``elem`` is of a type this compiler cannot render.
        """
        if elem is None:
            return "NULL"
        elif isinstance(elem, Compilable):
            # Nested elements are compiled by a non-root copy of this compiler.
            return elem.compile(self.replace(root=False))
        elif isinstance(elem, str):
            # NOTE(review): the value is interpolated as-is; embedded single
            # quotes are not escaped. Confirm callers never pass untrusted text.
            return f"'{elem}'"
        elif isinstance(elem, (int, float)):
            return str(elem)
        elif isinstance(elem, datetime):
            return self.dialect.timestamp_value(elem)
        elif isinstance(elem, bytes):
            return f"b'{elem.decode()}'"
        elif isinstance(elem, ArithString):
            return f"'{elem}'"
        # Raise explicitly instead of `assert False`: asserts are stripped
        # under `python -O`, which would make this silently return None.
        raise CompileError(f"Cannot compile element of unsupported type {type(elem).__name__}: {elem!r}")

    def new_unique_name(self, prefix="tmp") -> str:
        """Return ``prefix`` followed by the next value of the counter."""
        self._counter[0] += 1
        return f"{prefix}{self._counter[0]}"

    def new_unique_table_name(self, prefix="tmp") -> DbPath:
        """Return a parsed table path for a new, randomly suffixed table name.

        The name is ``prefix``, the next counter value, an underscore, and a
        random 32-bit number in lowercase hex, passed through
        ``database.parse_table_name``.
        """
        self._counter[0] += 1
        suffix = f"{random.randrange(2**32):x}"
        return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{suffix}")

    def add_table_context(self, *tables: Sequence, **kw) -> Self:
        """Return a copy with ``tables`` appended to the table context.

        Extra keyword arguments are forwarded to ``replace``.
        """
        return self.replace(_table_context=self._table_context + list(tables), **kw)

    def quote(self, s: str):
        """Quote ``s`` using the database dialect's quoting rules."""
        return self.dialect.quote(s)