1
1
import abc
2
2
import functools
3
- from dataclasses import field
3
+ import random
4
4
from datetime import datetime
5
5
import math
6
6
import sys
14
14
import decimal
15
15
import contextvars
16
16
17
- from runtype import dataclass
17
+ import attrs
18
18
from typing_extensions import Self
19
19
20
20
from data_diff .abcs .compiler import AbstractCompiler
@@ -90,12 +90,7 @@ class CompileError(Exception):
90
90
pass
91
91
92
92
93
- # TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94
- class _RuntypeHackToFixCicularRefrencedDatabase :
95
- dialect : "BaseDialect"
96
-
97
-
98
- @dataclass
93
+ @attrs .define (frozen = True )
99
94
class Compiler (AbstractCompiler ):
100
95
"""
101
96
Compiler bears the context for a single compilation.
@@ -107,16 +102,16 @@ class Compiler(AbstractCompiler):
107
102
# Database is needed to normalize tables. Dialect is needed for recursive compilations.
108
103
# In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
109
104
# In practice, we currently bind the dialects to the specific database classes.
110
- database : _RuntypeHackToFixCicularRefrencedDatabase
105
+ database : "Database"
111
106
112
107
in_select : bool = False # Compilation runtime flag
113
108
in_join : bool = False # Compilation runtime flag
114
109
115
- _table_context : List = field (default_factory = list ) # List[ITable]
116
- _subqueries : Dict [str , Any ] = field (default_factory = dict ) # XXX not thread-safe
110
+ _table_context : List = attrs . field (factory = list ) # List[ITable]
111
+ _subqueries : Dict [str , Any ] = attrs . field (factory = dict ) # XXX not thread-safe
117
112
root : bool = True
118
113
119
- _counter : List = field (default_factory = lambda : [0 ])
114
+ _counter : List = attrs . field (factory = lambda : [0 ])
120
115
121
116
@property
122
117
def dialect (self ) -> "BaseDialect" :
@@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
136
131
return self .database .dialect .parse_table_name (table_name )
137
132
138
133
def add_table_context (self , * tables : Sequence , ** kw ) -> Self :
139
- return self . replace ( _table_context = self ._table_context + list (tables ), ** kw )
134
+ return attrs . evolve ( self , table_context = self ._table_context + list (tables ), ** kw )
140
135
141
136
142
137
def parse_table_name (t ):
@@ -271,7 +266,7 @@ def _compile(self, compiler: Compiler, elem) -> str:
271
266
if elem is None :
272
267
return "NULL"
273
268
elif isinstance (elem , Compilable ):
274
- return self .render_compilable (compiler . replace ( root = False ), elem )
269
+ return self .render_compilable (attrs . evolve ( compiler , root = False ), elem )
275
270
elif isinstance (elem , str ):
276
271
return f"'{ elem } '"
277
272
elif isinstance (elem , (int , float )):
@@ -381,7 +376,7 @@ def render_column(self, c: Compiler, elem: Column) -> str:
381
376
return self .quote (elem .name )
382
377
383
378
def render_cte (self , parent_c : Compiler , elem : Cte ) -> str :
384
- c : Compiler = parent_c . replace ( _table_context = [], in_select = False )
379
+ c : Compiler = attrs . evolve ( parent_c , table_context = [], in_select = False )
385
380
compiled = self .compile (c , elem .source_table )
386
381
387
382
name = elem .name or parent_c .new_unique_name ()
@@ -494,7 +489,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
494
489
return f"{ self .compile (c , elem .source_table )} { self .quote (elem .name )} "
495
490
496
491
def render_tableop (self , parent_c : Compiler , elem : TableOp ) -> str :
497
- c : Compiler = parent_c . replace ( in_select = False )
492
+ c : Compiler = attrs . evolve ( parent_c , in_select = False )
498
493
table_expr = f"{ self .compile (c , elem .table1 )} { elem .op } { self .compile (c , elem .table2 )} "
499
494
if parent_c .in_select :
500
495
table_expr = f"({ table_expr } ) { c .new_unique_name ()} "
@@ -506,7 +501,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
506
501
return self .compile (c , elem ._get_resolved ())
507
502
508
503
def render_select (self , parent_c : Compiler , elem : Select ) -> str :
509
- c : Compiler = parent_c . replace ( in_select = True ) # .add_table_context(self.table)
504
+ c : Compiler = attrs . evolve ( parent_c , in_select = True ) # .add_table_context(self.table)
510
505
compile_fn = functools .partial (self .compile , c )
511
506
512
507
columns = ", " .join (map (compile_fn , elem .columns )) if elem .columns else "*"
@@ -544,7 +539,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
544
539
545
540
def render_join (self , parent_c : Compiler , elem : Join ) -> str :
546
541
tables = [
547
- t if isinstance (t , TableAlias ) else TableAlias (t , parent_c .new_unique_name ()) for t in elem .source_tables
542
+ t if isinstance (t , TableAlias ) else TableAlias (source_table = t , name = parent_c .new_unique_name ())
543
+ for t in elem .source_tables
548
544
]
549
545
c = parent_c .add_table_context (* tables , in_join = True , in_select = False )
550
546
op = " JOIN " if elem .op is None else f" { elem .op } JOIN "
@@ -577,7 +573,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
577
573
if isinstance (elem .table , Select ) and elem .table .columns is None and elem .table .group_by_exprs is None :
578
574
return self .compile (
579
575
c ,
580
- elem .table .replace (
576
+ attrs .evolve (
577
+ elem .table ,
581
578
columns = columns ,
582
579
group_by_exprs = [Code (k ) for k in keys ],
583
580
having_exprs = elem .having_exprs ,
@@ -589,7 +586,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
589
586
having_str = (
590
587
" HAVING " + " AND " .join (map (compile_fn , elem .having_exprs )) if elem .having_exprs is not None else ""
591
588
)
592
- select = f"SELECT { columns_str } FROM { self .compile (c . replace ( in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
589
+ select = f"SELECT { columns_str } FROM { self .compile (attrs . evolve ( c , in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
593
590
594
591
if c .in_select :
595
592
select = f"({ select } ) { c .new_unique_name ()} "
@@ -815,7 +812,7 @@ def set_timezone_to_utc(self) -> str:
815
812
T = TypeVar ("T" , bound = BaseDialect )
816
813
817
814
818
- @dataclass
815
+ @attrs . define ( frozen = True )
819
816
class QueryResult :
820
817
rows : list
821
818
columns : Optional [list ] = None
@@ -830,7 +827,7 @@ def __getitem__(self, i):
830
827
return self .rows [i ]
831
828
832
829
833
- class Database (abc .ABC , _RuntypeHackToFixCicularRefrencedDatabase ):
830
+ class Database (abc .ABC ):
834
831
"""Base abstract class for databases.
835
832
836
833
Used for providing connection code and implementation specific SQL utilities.
0 commit comments