1
1
import abc
2
2
import functools
3
- from dataclasses import field
3
+ import random
4
4
from datetime import datetime
5
5
import math
6
6
import sys
14
14
import decimal
15
15
import contextvars
16
16
17
- from runtype import dataclass
17
+ import attrs
18
18
from typing_extensions import Self
19
19
20
20
from data_diff .abcs .compiler import AbstractCompiler
@@ -90,12 +90,7 @@ class CompileError(Exception):
90
90
pass
91
91
92
92
93
- # TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94
- class _RuntypeHackToFixCicularRefrencedDatabase :
95
- dialect : "BaseDialect"
96
-
97
-
98
- @dataclass
93
+ @attrs .define (frozen = True )
99
94
class Compiler (AbstractCompiler ):
100
95
"""
101
96
Compiler bears the context for a single compilation.
@@ -107,16 +102,16 @@ class Compiler(AbstractCompiler):
107
102
# Database is needed to normalize tables. Dialect is needed for recursive compilations.
108
103
# In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
109
104
# In practice, we currently bind the dialects to the specific database classes.
110
- database : _RuntypeHackToFixCicularRefrencedDatabase
105
+ database : "Database"
111
106
112
107
in_select : bool = False # Compilation runtime flag
113
108
in_join : bool = False # Compilation runtime flag
114
109
115
- _table_context : List = field (default_factory = list ) # List[ITable]
116
- _subqueries : Dict [str , Any ] = field (default_factory = dict ) # XXX not thread-safe
110
+ _table_context : List = attrs . field (factory = list ) # List[ITable]
111
+ _subqueries : Dict [str , Any ] = attrs . field (factory = dict ) # XXX not thread-safe
117
112
root : bool = True
118
113
119
- _counter : List = field (default_factory = lambda : [0 ])
114
+ _counter : List = attrs . field (factory = lambda : [0 ])
120
115
121
116
@property
122
117
def dialect (self ) -> "BaseDialect" :
@@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
136
131
return self .database .dialect .parse_table_name (table_name )
137
132
138
133
def add_table_context (self , * tables : Sequence , ** kw ) -> Self :
139
- return self . replace ( _table_context = self ._table_context + list (tables ), ** kw )
134
+ return attrs . evolve ( self , table_context = self ._table_context + list (tables ), ** kw )
140
135
141
136
142
137
def parse_table_name (t ):
@@ -272,7 +267,7 @@ def _compile(self, compiler: Compiler, elem) -> str:
272
267
if elem is None :
273
268
return "NULL"
274
269
elif isinstance (elem , Compilable ):
275
- return self .render_compilable (compiler . replace ( root = False ), elem )
270
+ return self .render_compilable (attrs . evolve ( compiler , root = False ), elem )
276
271
elif isinstance (elem , str ):
277
272
return f"'{ elem } '"
278
273
elif isinstance (elem , (int , float )):
@@ -382,7 +377,7 @@ def render_column(self, c: Compiler, elem: Column) -> str:
382
377
return self .quote (elem .name )
383
378
384
379
def render_cte (self , parent_c : Compiler , elem : Cte ) -> str :
385
- c : Compiler = parent_c . replace ( _table_context = [], in_select = False )
380
+ c : Compiler = attrs . evolve ( parent_c , table_context = [], in_select = False )
386
381
compiled = self .compile (c , elem .source_table )
387
382
388
383
name = elem .name or parent_c .new_unique_name ()
@@ -495,7 +490,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
495
490
return f"{ self .compile (c , elem .source_table )} { self .quote (elem .name )} "
496
491
497
492
def render_tableop (self , parent_c : Compiler , elem : TableOp ) -> str :
498
- c : Compiler = parent_c . replace ( in_select = False )
493
+ c : Compiler = attrs . evolve ( parent_c , in_select = False )
499
494
table_expr = f"{ self .compile (c , elem .table1 )} { elem .op } { self .compile (c , elem .table2 )} "
500
495
if parent_c .in_select :
501
496
table_expr = f"({ table_expr } ) { c .new_unique_name ()} "
@@ -507,7 +502,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
507
502
return self .compile (c , elem ._get_resolved ())
508
503
509
504
def render_select (self , parent_c : Compiler , elem : Select ) -> str :
510
- c : Compiler = parent_c . replace ( in_select = True ) # .add_table_context(self.table)
505
+ c : Compiler = attrs . evolve ( parent_c , in_select = True ) # .add_table_context(self.table)
511
506
compile_fn = functools .partial (self .compile , c )
512
507
513
508
columns = ", " .join (map (compile_fn , elem .columns )) if elem .columns else "*"
@@ -545,7 +540,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
545
540
546
541
def render_join (self , parent_c : Compiler , elem : Join ) -> str :
547
542
tables = [
548
- t if isinstance (t , TableAlias ) else TableAlias (t , parent_c .new_unique_name ()) for t in elem .source_tables
543
+ t if isinstance (t , TableAlias ) else TableAlias (source_table = t , name = parent_c .new_unique_name ())
544
+ for t in elem .source_tables
549
545
]
550
546
c = parent_c .add_table_context (* tables , in_join = True , in_select = False )
551
547
op = " JOIN " if elem .op is None else f" { elem .op } JOIN "
@@ -578,7 +574,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
578
574
if isinstance (elem .table , Select ) and elem .table .columns is None and elem .table .group_by_exprs is None :
579
575
return self .compile (
580
576
c ,
581
- elem .table .replace (
577
+ attrs .evolve (
578
+ elem .table ,
582
579
columns = columns ,
583
580
group_by_exprs = [Code (k ) for k in keys ],
584
581
having_exprs = elem .having_exprs ,
@@ -590,7 +587,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
590
587
having_str = (
591
588
" HAVING " + " AND " .join (map (compile_fn , elem .having_exprs )) if elem .having_exprs is not None else ""
592
589
)
593
- select = f"SELECT { columns_str } FROM { self .compile (c . replace ( in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
590
+ select = f"SELECT { columns_str } FROM { self .compile (attrs . evolve ( c , in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
594
591
595
592
if c .in_select :
596
593
select = f"({ select } ) { c .new_unique_name ()} "
@@ -827,7 +824,7 @@ def set_timezone_to_utc(self) -> str:
827
824
T = TypeVar ("T" , bound = BaseDialect )
828
825
829
826
830
- @dataclass
827
+ @attrs . define ( frozen = True )
831
828
class QueryResult :
832
829
rows : list
833
830
columns : Optional [list ] = None
@@ -842,7 +839,7 @@ def __getitem__(self, i):
842
839
return self .rows [i ]
843
840
844
841
845
- class Database (abc .ABC , _RuntypeHackToFixCicularRefrencedDatabase ):
842
+ class Database (abc .ABC ):
846
843
"""Base abstract class for databases.
847
844
848
845
Used for providing connection code and implementation specific SQL utilities.
0 commit comments