From 651b8bcd79ff2be9d8e99194efed9654d5f5213f Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Mon, 25 Sep 2023 15:38:27 +0200
Subject: [PATCH 1/4] Use dialect's quoting directly, not via compiler

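In practice this only changes call sites from the Compiler shortcut to the
dialect itself. An illustrative before/after, taken from Alias below:

    # before
    return f"{c.compile(self.expr)} AS {c.quote(self.name)}"
    # after
    return f"{c.compile(self.expr)} AS {c.dialect.quote(self.name)}"

Compiler.quote() is removed, so identifier quoting has a single home on the
dialect.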
---
 data_diff/queries/ast_classes.py | 14 +++++++-------
 data_diff/queries/compiler.py    |  3 ---
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py
index 70cb355f..954783de 100644
--- a/data_diff/queries/ast_classes.py
+++ b/data_diff/queries/ast_classes.py
@@ -74,7 +74,7 @@ class Alias(ExprNode):
     name: str
 
     def compile(self, c: Compiler) -> str:
-        return f"{c.compile(self.expr)} AS {c.quote(self.name)}"
+        return f"{c.compile(self.expr)} AS {c.dialect.quote(self.name)}"
 
     @property
     def type(self):
@@ -408,14 +408,14 @@ def compile(self, c: Compiler) -> str:
                     t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table
                 ]
                 if not aliases:
-                    return c.quote(self.name)
+                    return c.dialect.quote(self.name)
                 elif len(aliases) > 1:
                     raise CompileError(f"Too many aliases for column {self.name}")
                 (alias,) = aliases
 
-                return f"{c.quote(alias.name)}.{c.quote(self.name)}"
+                return f"{c.dialect.quote(alias.name)}.{c.dialect.quote(self.name)}"
 
-        return c.quote(self.name)
+        return c.dialect.quote(self.name)
 
 
 @dataclass
@@ -429,7 +429,7 @@ def source_table(self) -> Self:
 
     def compile(self, c: Compiler) -> str:
         path = self.path  # c.database._normalize_table_path(self.name)
-        return ".".join(map(c.quote, path))
+        return ".".join(map(c.dialect.quote, path))
 
     # Statement shorthands
     def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None):
@@ -515,7 +515,7 @@ class TableAlias(ExprNode, ITable):
     name: str
 
     def compile(self, c: Compiler) -> str:
-        return f"{c.compile(self.source_table)} {c.quote(self.name)}"
+        return f"{c.compile(self.source_table)} {c.dialect.quote(self.name)}"
 
 
 @dataclass
@@ -989,7 +989,7 @@ def compile(self, c: Compiler) -> str:
         else:
             expr = c.compile(self.expr)
 
-        columns = "(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else ""
+        columns = "(%s)" % ", ".join(map(c.dialect.quote, self.columns)) if self.columns is not None else ""
 
         return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
 
diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py
index e6246236..a08d85ea 100644
--- a/data_diff/queries/compiler.py
+++ b/data_diff/queries/compiler.py
@@ -83,6 +83,3 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
 
     def add_table_context(self, *tables: Sequence, **kw) -> Self:
         return self.replace(_table_context=self._table_context + list(tables), **kw)
-
-    def quote(self, s: str):
-        return self.dialect.quote(s)

From bf4eeea5d0876457f1b4a541f597e28f7bfbeb5f Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Mon, 25 Sep 2023 15:04:01 +0200
Subject: [PATCH 2/4] Move table name parsing to dialects, where it
 semantically belongs

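Callers now resolve table names through the dialect, e.g.
db.dialect.parse_table_name(name) instead of db.parse_table_name(name).
A rough sketch of the intended usage (the connection string is a placeholder,
and the resulting tuple assumes the generic dot-splitting parse_table_name()
helper used by BaseDialect):

    db = connect("postgresql://...")
    path = db.dialect.parse_table_name("public.ratings")
    # expected: a DbPath tuple along the lines of ("public", "ratings")

BigQuery and Databricks keep their own overrides, which additionally filter
out None components from the parsed path.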
---
 data_diff/__init__.py             | 8 +++++---
 data_diff/__main__.py             | 9 +++++----
 data_diff/abcs/database_types.py  | 8 ++++----
 data_diff/databases/base.py       | 6 +++---
 data_diff/databases/bigquery.py   | 8 ++++----
 data_diff/databases/databricks.py | 8 ++++----
 data_diff/queries/compiler.py     | 3 ++-
 7 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/data_diff/__init__.py b/data_diff/__init__.py
index 60c79b10..4ae223fb 100644
--- a/data_diff/__init__.py
+++ b/data_diff/__init__.py
@@ -1,6 +1,7 @@
 from typing import Sequence, Tuple, Iterator, Optional, Union
 
 from data_diff.abcs.database_types import DbTime, DbPath
+from data_diff.databases import Database
 from data_diff.tracking import disable_tracking
 from data_diff.databases._connect import connect
 from data_diff.diff_tables import Algorithm
@@ -31,10 +32,10 @@ def connect_to_table(
     if isinstance(key_columns, str):
         key_columns = (key_columns,)
 
-    db = connect(db_info, thread_count=thread_count)
+    db: Database = connect(db_info, thread_count=thread_count)
 
     if isinstance(table_name, str):
-        table_name = db.parse_table_name(table_name)
+        table_name = db.dialect.parse_table_name(table_name)
 
     return TableSegment(db, table_name, key_columns, **kwargs)
 
@@ -161,7 +162,8 @@ def diff_tables(
         )
     elif algorithm == Algorithm.JOINDIFF:
         if isinstance(materialize_to_table, str):
-            materialize_to_table = table1.database.parse_table_name(eval_name_template(materialize_to_table))
+            table_name = eval_name_template(materialize_to_table)
+            materialize_to_table = table1.database.dialect.parse_table_name(table_name)
         differ = JoinDiffer(
             threaded=threaded,
             max_threadpool_size=max_threadpool_size,
diff --git a/data_diff/__main__.py b/data_diff/__main__.py
index 77dc7fb6..0e5255e6 100644
--- a/data_diff/__main__.py
+++ b/data_diff/__main__.py
@@ -6,12 +6,13 @@
 import json
 import logging
 from itertools import islice
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 import rich
 from rich.logging import RichHandler
 import click
 
+from data_diff import Database
 from data_diff.schema import create_schema
 from data_diff.queries.api import current_timestamp
 
@@ -425,7 +426,7 @@ def _data_diff(
             logging.error(f"Error while parsing age expression: {e}")
             return
 
-    dbs = db1, db2
+    dbs: Tuple[Database, Database] = db1, db2
 
     if interactive:
         for db in dbs:
@@ -444,7 +445,7 @@ def _data_diff(
             materialize_all_rows=materialize_all_rows,
             table_write_limit=table_write_limit,
             materialize_to_table=materialize_to_table
-            and db1.parse_table_name(eval_name_template(materialize_to_table)),
+            and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)),
         )
     else:
         assert algorithm == Algorithm.HASHDIFF
@@ -456,7 +457,7 @@ def _data_diff(
         )
 
     table_names = table1, table2
-    table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
+    table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)]
 
     schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
     schema1, schema2 = schemas = [
diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py
index 82ec8352..c811ace5 100644
--- a/data_diff/abcs/database_types.py
+++ b/data_diff/abcs/database_types.py
@@ -176,6 +176,10 @@ class UnknownColType(ColType):
 class AbstractDialect(ABC):
     """Dialect-dependent query expressions"""
 
+    @abstractmethod
+    def parse_table_name(self, name: str) -> DbPath:
+        "Parse the given table name into a DbPath"
+
     @property
     @abstractmethod
     def name(self) -> str:
@@ -319,10 +323,6 @@ def _process_table_schema(
 
         """
 
-    @abstractmethod
-    def parse_table_name(self, name: str) -> DbPath:
-        "Parse the given table name into a DbPath"
-
     @abstractmethod
     def close(self):
         "Close connection(s) to the database instance. Querying will stop functioning."
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index a89ab74e..082e9815 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -156,6 +156,9 @@ class BaseDialect(AbstractDialect):
 
     PLACEHOLDER_TABLE = None  # Used for Oracle
 
+    def parse_table_name(self, name: str) -> DbPath:
+        return parse_table_name(name)
+
     def offset_limit(
         self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
     ) -> str:
@@ -518,9 +521,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
 
         raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")
 
-    def parse_table_name(self, name: str) -> DbPath:
-        return parse_table_name(name)
-
     def _query_cursor(self, c, sql_code: str) -> QueryResult:
         assert isinstance(sql_code, str), sql_code
         try:
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py
index 5925234f..feb98bde 100644
--- a/data_diff/databases/bigquery.py
+++ b/data_diff/databases/bigquery.py
@@ -212,6 +212,10 @@ def to_comparable(self, value: str, coltype: ColType) -> str:
     def set_timezone_to_utc(self) -> str:
         raise NotImplementedError()
 
+    def parse_table_name(self, name: str) -> DbPath:
+        path = parse_table_name(name)
+        return tuple(i for i in path if i is not None)
+
 
 class BigQuery(Database):
     CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
@@ -288,10 +292,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
                 f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table"
             )
 
-    def parse_table_name(self, name: str) -> DbPath:
-        path = parse_table_name(name)
-        return tuple(i for i in self._normalize_table_path(path) if i is not None)
-
     @property
     def is_autocommit(self) -> bool:
         return True
diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py
index 1b8aa33a..67d0528d 100644
--- a/data_diff/databases/databricks.py
+++ b/data_diff/databases/databricks.py
@@ -94,6 +94,10 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
     def set_timezone_to_utc(self) -> str:
         return "SET TIME ZONE 'UTC'"
 
+    def parse_table_name(self, name: str) -> DbPath:
+        path = parse_table_name(name)
+        return tuple(i for i in path if i is not None)
+
 
 class Databricks(ThreadedDatabase):
     dialect = Dialect()
@@ -178,10 +182,6 @@ def _process_table_schema(
         self._refine_coltypes(path, col_dict, where)
         return col_dict
 
-    def parse_table_name(self, name: str) -> DbPath:
-        path = parse_table_name(name)
-        return tuple(i for i in self._normalize_table_path(path) if i is not None)
-
     @property
     def is_autocommit(self) -> bool:
         return True
diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py
index a08d85ea..14eb0b77 100644
--- a/data_diff/queries/compiler.py
+++ b/data_diff/queries/compiler.py
@@ -79,7 +79,8 @@ def new_unique_name(self, prefix="tmp"):
 
     def new_unique_table_name(self, prefix="tmp") -> DbPath:
         self._counter[0] += 1
-        return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")
+        table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}"
+        return self.database.dialect.parse_table_name(table_name)
 
     def add_table_context(self, *tables: Sequence, **kw) -> Self:
         return self.replace(_table_context=self._table_context + list(tables), **kw)

From b8cf4827d02fb6fd0491c00af56343260699224b Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Mon, 25 Sep 2023 16:05:28 +0200
Subject: [PATCH 3/4] Remove compiler's unused params

---
 data_diff/queries/compiler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py
index 14eb0b77..b0df04eb 100644
--- a/data_diff/queries/compiler.py
+++ b/data_diff/queries/compiler.py
@@ -26,7 +26,7 @@ class Root:
 @dataclass
 class Compiler(AbstractCompiler):
     database: AbstractDatabase
-    params: dict = field(default_factory=dict)
+
     in_select: bool = False  # Compilation runtime flag
     in_join: bool = False  # Compilation runtime flag
 

From 795bb0ec2a6ad95104a6a37f003a33b222f94c42 Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Mon, 25 Sep 2023 15:34:47 +0200
Subject: [PATCH 4/4] Compile all AST elements via dialects, never
 directly
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The root authority on how to render SQL syntax properly is the dialect, not the AST element itself. The AST must only carry the intention of what we want to execute, not how it should or could be executed.

This makes the AST truly independent of dialects and databases, allowing us to:
- Focus on the main logic regardless of each database's SQL capabilities.
- Create custom database connectors without requiring AST changes every time.

Before the change, adding a new database connector with unusual syntax often required changing the AST elements to delegate to `compiler.dialect.some_method()`, because only a subset of SQL was described in the dialects. The rest of the SQL syntax, rather arbitrarily, was hard-coded in the AST elements and could not be easily overridden without such changes.

After the change, all of the SQL rendering logic is concentrated in a single hierarchy of dialects, mostly in one base class.
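As an illustration of the new extension point, a connector-specific dialect
can now override the rendering of a single AST node without touching the AST
classes. A minimal sketch (the dialect and its backtick/rand() syntax are
invented for the example; a real subclass must also implement the remaining
abstract methods of BaseDialect):

    from data_diff.databases.base import BaseDialect

    class MyDialect(BaseDialect):
        def quote(self, s: str) -> str:
            # hypothetical engine that quotes identifiers with backticks
            return f"`{s}`"

        def render_random(self, c, elem) -> str:
            # hypothetical engine that spells RANDOM() as rand()
            return "rand()"

All other nodes keep the rendering inherited from BaseDialect.render_compilable(),
and the AST itself stays dialect-agnostic.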
---
 data_diff/abcs/compiler.py       |  12 +-
 data_diff/abcs/database_types.py |   7 +-
 data_diff/bound_exprs.py         |   5 -
 data_diff/databases/base.py      | 449 ++++++++++++++++++++++++++++++-
 data_diff/joindiff_tables.py     |  10 +-
 data_diff/queries/ast_classes.py | 273 +------------------
 data_diff/queries/compiler.py    |  55 +---
 data_diff/queries/extras.py      |  41 +--
 8 files changed, 475 insertions(+), 377 deletions(-)

diff --git a/data_diff/abcs/compiler.py b/data_diff/abcs/compiler.py
index 72fd7578..4a847d05 100644
--- a/data_diff/abcs/compiler.py
+++ b/data_diff/abcs/compiler.py
@@ -1,15 +1,9 @@
-from typing import Any, Dict
-from abc import ABC, abstractmethod
+from abc import ABC
 
 
 class AbstractCompiler(ABC):
-    @abstractmethod
-    def compile(self, elem: Any, params: Dict[str, Any] = None) -> str:
-        ...
+    pass
 
 
 class Compilable(ABC):
-    # TODO generic syntax, so we can write Compilable[T] for expressions returning a value of type T
-    @abstractmethod
-    def compile(self, c: AbstractCompiler) -> str:
-        ...
+    pass
diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py
index c811ace5..a679db67 100644
--- a/data_diff/abcs/database_types.py
+++ b/data_diff/abcs/database_types.py
@@ -1,11 +1,12 @@
 import decimal
 from abc import ABC, abstractmethod
-from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
+from typing import Sequence, Optional, Tuple, Union, Dict, List
 from datetime import datetime
 
 from runtype import dataclass
 from typing_extensions import Self
 
+from data_diff.abcs.compiler import AbstractCompiler
 from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown
 
 
@@ -176,6 +177,10 @@ class UnknownColType(ColType):
 class AbstractDialect(ABC):
     """Dialect-dependent query expressions"""
 
+    @abstractmethod
+    def compile(self, compiler: AbstractCompiler, elem, params=None) -> str:
+        raise NotImplementedError
+
     @abstractmethod
     def parse_table_name(self, name: str) -> DbPath:
         "Parse the given table name into a DbPath"
diff --git a/data_diff/bound_exprs.py b/data_diff/bound_exprs.py
index 1742b74c..4b53846d 100644
--- a/data_diff/bound_exprs.py
+++ b/data_diff/bound_exprs.py
@@ -8,7 +8,6 @@
 from typing_extensions import Self
 
 from data_diff.abcs.database_types import AbstractDatabase
-from data_diff.abcs.compiler import AbstractCompiler
 from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable
 from data_diff.queries.api import table
 from data_diff.schema import create_schema
@@ -37,10 +36,6 @@ def query(self, res_type=list):
     def type(self):
         return self.node.type
 
-    def compile(self, c: AbstractCompiler) -> str:
-        assert c.database is self.database
-        return self.node.compile(c)
-
 
 def bind_node(node, database):
     return BoundNode(database, node)
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index 082e9815..dc43d8d7 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -1,3 +1,4 @@
+import functools
 from datetime import datetime
 import math
 import sys
@@ -9,13 +10,25 @@
 from abc import abstractmethod
 from uuid import UUID
 import decimal
+import contextvars
 
 from runtype import dataclass
 from typing_extensions import Self
 
-from data_diff.utils import is_uuid, safezip
+from data_diff.queries.compiler import CompileError
+from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
+from data_diff.utils import ArithString, is_uuid, join_iter, safezip
 from data_diff.queries.api import Expr, Compiler, table, Select, SKIP, Explain, Code, this
-from data_diff.queries.ast_classes import Random
+from data_diff.queries.ast_classes import (
+    Alias, BinOp, CaseWhen, Cast, Column,
+    Commit, Concat, ConstantTable, Count,
+    CreateTable, Cte, CurrentTimestamp, DropTable,
+    Func, GroupBy, In, InsertToTable, IsDistinctFrom,
+    Join, Param, Random, Root, TableAlias,
+    TableOp, TablePath, TestRegex, TimeTravel,
+    TruncateTable, UnaryOp, WhenThen,
+    _ResolveColumn,
+)
 from data_diff.abcs.database_types import (
     AbstractDatabase,
     Array,
@@ -39,16 +52,17 @@
     Boolean,
     JSON,
 )
-from data_diff.abcs.mixins import Compilable
+from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel, Compilable
 from data_diff.abcs.mixins import (
     AbstractMixin_Schema,
     AbstractMixin_RandomSample,
     AbstractMixin_NormalizeValue,
     AbstractMixin_OptimizerHints,
 )
-from data_diff.bound_exprs import bound_table
+from data_diff.bound_exprs import BoundNode, bound_table
 
 logger = logging.getLogger("database")
+cv_params = contextvars.ContextVar("params")
 
 
 def parse_table_name(t):
@@ -98,7 +112,7 @@ def __init__(self, compiler: Compiler, gen: Generator):
     def apply_queries(self, callback: Callable[[str], Any]):
         q: Expr = next(self.gen)
         while True:
-            sql = self.compiler.compile(q)
+            sql = self.compiler.database.dialect.compile(self.compiler, q)
             logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql)
             try:
                 try:
@@ -159,6 +173,424 @@ class BaseDialect(AbstractDialect):
     def parse_table_name(self, name: str) -> DbPath:
         return parse_table_name(name)
 
+    def compile(self, compiler: Compiler, elem, params=None) -> str:
+        if params:
+            cv_params.set(params)
+
+        if compiler.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
+            from data_diff.queries.ast_classes import Select
+
+            elem = Select(columns=[elem])
+
+        res = self._compile(compiler, elem)
+        if compiler.root and compiler._subqueries:
+            subq = ", ".join(f"\n  {k} AS ({v})" for k, v in compiler._subqueries.items())
+            compiler._subqueries.clear()
+            return f"WITH {subq}\n{res}"
+        return res
+
+    def _compile(self, compiler: Compiler, elem) -> str:
+        if elem is None:
+            return "NULL"
+        elif isinstance(elem, Compilable):
+            return self.render_compilable(compiler.replace(root=False), elem)
+        elif isinstance(elem, str):
+            return f"'{elem}'"
+        elif isinstance(elem, (int, float)):
+            return str(elem)
+        elif isinstance(elem, datetime):
+            return self.timestamp_value(elem)
+        elif isinstance(elem, bytes):
+            return f"b'{elem.decode()}'"
+        elif isinstance(elem, ArithString):
+            return f"'{elem}'"
+        assert False, elem
+
+    def render_compilable(self, c: Compiler, elem: Compilable) -> str:
+        # All ifs are only for better code navigation, IDE usage detection, and type checking.
+        # The last catch-all would render them anyway — it is a typical "visitor" pattern.
+        if isinstance(elem, Column):
+            return self.render_column(c, elem)
+        elif isinstance(elem, Cte):
+            return self.render_cte(c, elem)
+        elif isinstance(elem, Commit):
+            return self.render_commit(c, elem)
+        elif isinstance(elem, Param):
+            return self.render_param(c, elem)
+        elif isinstance(elem, NormalizeAsString):
+            return self.render_normalizeasstring(c, elem)
+        elif isinstance(elem, ApplyFuncAndNormalizeAsString):
+            return self.render_applyfuncandnormalizeasstring(c, elem)
+        elif isinstance(elem, Checksum):
+            return self.render_checksum(c, elem)
+        elif isinstance(elem, Concat):
+            return self.render_concat(c, elem)
+        elif isinstance(elem, TestRegex):
+            return self.render_testregex(c, elem)
+        elif isinstance(elem, Func):
+            return self.render_func(c, elem)
+        elif isinstance(elem, WhenThen):
+            return self.render_whenthen(c, elem)
+        elif isinstance(elem, CaseWhen):
+            return self.render_casewhen(c, elem)
+        elif isinstance(elem, IsDistinctFrom):
+            return self.render_isdistinctfrom(c, elem)
+        elif isinstance(elem, UnaryOp):
+            return self.render_unaryop(c, elem)
+        elif isinstance(elem, BinOp):
+            return self.render_binop(c, elem)
+        elif isinstance(elem, TablePath):
+            return self.render_tablepath(c, elem)
+        elif isinstance(elem, TableAlias):
+            return self.render_tablealias(c, elem)
+        elif isinstance(elem, TableOp):
+            return self.render_tableop(c, elem)
+        elif isinstance(elem, Select):
+            return self.render_select(c, elem)
+        elif isinstance(elem, Join):
+            return self.render_join(c, elem)
+        elif isinstance(elem, GroupBy):
+            return self.render_groupby(c, elem)
+        elif isinstance(elem, Count):
+            return self.render_count(c, elem)
+        elif isinstance(elem, Alias):
+            return self.render_alias(c, elem)
+        elif isinstance(elem, In):
+            return self.render_in(c, elem)
+        elif isinstance(elem, Cast):
+            return self.render_cast(c, elem)
+        elif isinstance(elem, Random):
+            return self.render_random(c, elem)
+        elif isinstance(elem, Explain):
+            return self.render_explain(c, elem)
+        elif isinstance(elem, CurrentTimestamp):
+            return self.render_currenttimestamp(c, elem)
+        elif isinstance(elem, TimeTravel):
+            return self.render_timetravel(c, elem)
+        elif isinstance(elem, CreateTable):
+            return self.render_createtable(c, elem)
+        elif isinstance(elem, DropTable):
+            return self.render_droptable(c, elem)
+        elif isinstance(elem, TruncateTable):
+            return self.render_truncatetable(c, elem)
+        elif isinstance(elem, InsertToTable):
+            return self.render_inserttotable(c, elem)
+        elif isinstance(elem, Code):
+            return self.render_code(c, elem)
+        elif isinstance(elem, BoundNode):
+            return self.render_boundnode(c, elem)
+        elif isinstance(elem, _ResolveColumn):
+            return self.render__resolvecolumn(c, elem)
+
+        method_name = f"render_{elem.__class__.__name__.lower()}"
+        method = getattr(self, method_name, None)
+        if method is not None:
+            return method(c, elem)
+        else:
+            raise RuntimeError(f"Cannot render AST of type {elem.__class__}")
+        # return elem.compile(compiler.replace(root=False))
+
+    def render_column(self, c: Compiler, elem: Column) -> str:
+        if c._table_context:
+            if len(c._table_context) > 1:
+                aliases = [
+                    t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is elem.source_table
+                ]
+                if not aliases:
+                    return self.quote(elem.name)
+                elif len(aliases) > 1:
+                    raise CompileError(f"Too many aliases for column {elem.name}")
+                (alias,) = aliases
+
+                return f"{self.quote(alias.name)}.{self.quote(elem.name)}"
+
+        return self.quote(elem.name)
+
+    def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
+        c: Compiler = parent_c.replace(_table_context=[], in_select=False)
+        compiled = self.compile(c, elem.source_table)
+
+        name = elem.name or parent_c.new_unique_name()
+        name_params = f"{name}({', '.join(elem.params)})" if elem.params else name
+        parent_c._subqueries[name_params] = compiled
+
+        return name
+
+    def render_commit(self, c: Compiler, elem: Commit) -> str:
+        return "COMMIT" if not c.database.is_autocommit else SKIP
+
+    def render_param(self, c: Compiler, elem: Param) -> str:
+        params = cv_params.get()
+        return self._compile(c, params[elem.name])
+
+    def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str:
+        expr = self.compile(c, elem.expr)
+        return self.normalize_value_by_type(expr, elem.expr_type or elem.expr.type)
+
+    def render_applyfuncandnormalizeasstring(self, c: Compiler, elem: ApplyFuncAndNormalizeAsString) -> str:
+        expr = elem.expr
+        expr_type = expr.type
+
+        if isinstance(expr_type, Native_UUID):
+            # Normalize first, apply template after (for uuids)
+            # Needed because min/max(uuid) fails in postgresql
+            expr = NormalizeAsString(expr, expr_type)
+            if elem.apply_func is not None:
+                expr = elem.apply_func(expr)  # Apply template using Python's string formatting
+
+        else:
+            # Apply template before normalizing (for ints)
+            if elem.apply_func is not None:
+                expr = elem.apply_func(expr)  # Apply template using Python's string formatting
+            expr = NormalizeAsString(expr, expr_type)
+
+        return self.compile(c, expr)
+
+    def render_checksum(self, c: Compiler, elem: Checksum) -> str:
+        if len(elem.exprs) > 1:
+            exprs = [Code(f"coalesce({self.compile(c, expr)}, '<null>')") for expr in elem.exprs]
+            # exprs = [self.compile(c, e) for e in exprs]
+            expr = Concat(exprs, "|")
+        else:
+            # No need to coalesce - safe to assume that key cannot be null
+            (expr,) = elem.exprs
+        expr = self.compile(c, expr)
+        md5 = self.md5_as_int(expr)
+        return f"sum({md5})"
+
+    def render_concat(self, c: Compiler, elem: Concat) -> str:
+        # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
+        items = [f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs]
+        assert items
+        if len(items) == 1:
+            return items[0]
+
+        if elem.sep:
+            items = list(join_iter(f"'{elem.sep}'", items))
+        return self.concat(items)
+
+    def render_alias(self, c: Compiler, elem: Alias) -> str:
+        return f"{self.compile(c, elem.expr)} AS {self.quote(elem.name)}"
+
+    def render_testregex(self, c: Compiler, elem: TestRegex) -> str:
+        # TODO: move this method to that mixin! raise here instead, unconditionally.
+        if not isinstance(self, AbstractMixin_Regex):
+            raise NotImplementedError(f"No regex implementation for database '{c.dialect}'")
+        regex = self.test_regex(elem.string, elem.pattern)
+        return self.compile(c, regex)
+
+    def render_count(self, c: Compiler, elem: Count) -> str:
+        expr = self.compile(c, elem.expr) if elem.expr else "*"
+        if elem.distinct:
+            return f"count(distinct {expr})"
+        return f"count({expr})"
+
+    def render_code(self, c: Compiler, elem: Code) -> str:
+        if not elem.args:
+            return elem.code
+
+        args = {k: self.compile(c, v) for k, v in elem.args.items()}
+        return elem.code.format(**args)
+
+    def render_func(self, c: Compiler, elem: Func) -> str:
+        args = ", ".join(self.compile(c, e) for e in elem.args)
+        return f"{elem.name}({args})"
+
+    def render_whenthen(self, c: Compiler, elem: WhenThen) -> str:
+        return f"WHEN {self.compile(c, elem.when)} THEN {self.compile(c, elem.then)}"
+
+    def render_casewhen(self, c: Compiler, elem: CaseWhen) -> str:
+        assert elem.cases
+        when_thens = " ".join(self.compile(c, case) for case in elem.cases)
+        else_expr = (" ELSE " + self.compile(c, elem.else_expr)) if elem.else_expr is not None else ""
+        return f"CASE {when_thens}{else_expr} END"
+
+    def render_isdistinctfrom(self, c: Compiler, elem: IsDistinctFrom) -> str:
+        a = self.to_comparable(self.compile(c, elem.a), elem.a.type)
+        b = self.to_comparable(self.compile(c, elem.b), elem.b.type)
+        return self.is_distinct_from(a, b)
+
+    def render_unaryop(self, c: Compiler, elem: UnaryOp) -> str:
+        return f"({elem.op}{self.compile(c, elem.expr)})"
+
+    def render_binop(self, c: Compiler, elem: BinOp) -> str:
+        expr = f" {elem.op} ".join(self.compile(c, a) for a in elem.args)
+        return f"({expr})"
+
+    def render_tablepath(self, c: Compiler, elem: TablePath) -> str:
+        path = elem.path  # c.database._normalize_table_path(self.name)
+        return ".".join(map(self.quote, path))
+
+    def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
+        return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}"
+
+    def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
+        c: Compiler = parent_c.replace(in_select=False)
+        table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}"
+        if parent_c.in_select:
+            table_expr = f"({table_expr}) {c.new_unique_name()}"
+        elif parent_c.in_join:
+            table_expr = f"({table_expr})"
+        return table_expr
+
+    def render_boundnode(self, c: Compiler, elem: BoundNode) -> str:
+        assert self is elem.database.dialect
+        return self.compile(c, elem.node)
+
+    def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
+        return self.compile(c, elem._get_resolved())
+
+    def render_select(self, parent_c: Compiler, elem: Select) -> str:
+        c: Compiler = parent_c.replace(in_select=True)  # .add_table_context(self.table)
+        compile_fn = functools.partial(self.compile, c)
+
+        columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*"
+        distinct = "DISTINCT " if elem.distinct else ""
+        optimizer_hints = self.optimizer_hints(elem.optimizer_hints) if elem.optimizer_hints else ""
+        select = f"SELECT {optimizer_hints}{distinct}{columns}"
+
+        if elem.table:
+            select += " FROM " + self.compile(c, elem.table)
+        elif self.PLACEHOLDER_TABLE:
+            select += f" FROM {self.PLACEHOLDER_TABLE}"
+
+        if elem.where_exprs:
+            select += " WHERE " + " AND ".join(map(compile_fn, elem.where_exprs))
+
+        if elem.group_by_exprs:
+            select += " GROUP BY " + ", ".join(map(compile_fn, elem.group_by_exprs))
+
+        if elem.having_exprs:
+            assert elem.group_by_exprs
+            select += " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs))
+
+        if elem.order_by_exprs:
+            select += " ORDER BY " + ", ".join(map(compile_fn, elem.order_by_exprs))
+
+        if elem.limit_expr is not None:
+            has_order_by = bool(elem.order_by_exprs)
+            select += " " + self.offset_limit(0, elem.limit_expr, has_order_by=has_order_by)
+
+        if parent_c.in_select:
+            select = f"({select}) {c.new_unique_name()}"
+        elif parent_c.in_join:
+            select = f"({select})"
+        return select
+
+    def render_join(self, parent_c: Compiler, elem: Join) -> str:
+        tables = [
+            t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables
+        ]
+        c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
+        op = " JOIN " if elem.op is None else f" {elem.op} JOIN "
+        joined = op.join(self.compile(c, t) for t in tables)
+
+        if elem.on_exprs:
+            on = " AND ".join(self.compile(c, e) for e in elem.on_exprs)
+            res = f"{joined} ON {on}"
+        else:
+            res = joined
+
+        compile_fn = functools.partial(self.compile, c)
+        columns = "*" if elem.columns is None else ", ".join(map(compile_fn, elem.columns))
+        select = f"SELECT {columns} FROM {res}"
+
+        if parent_c.in_select:
+            select = f"({select}) {c.new_unique_name()}"
+        elif parent_c.in_join:
+            select = f"({select})"
+        return select
+
+    def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
+        compile_fn = functools.partial(self.compile, c)
+
+        if elem.values is None:
+            raise CompileError(".group_by() must be followed by a call to .agg()")
+
+        keys = [str(i + 1) for i in range(len(elem.keys))]
+        columns = (elem.keys or []) + (elem.values or [])
+        if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None:
+            return self.compile(
+                c,
+                elem.table.replace(
+                    columns=columns,
+                    group_by_exprs=[Code(k) for k in keys],
+                    having_exprs=elem.having_exprs,
+                )
+            )
+
+        keys_str = ", ".join(keys)
+        columns_str = ", ".join(self.compile(c, x) for x in columns)
+        having_str = (
+            " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
+        )
+        select = (
+            f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
+        )
+
+        if c.in_select:
+            select = f"({select}) {c.new_unique_name()}"
+        elif c.in_join:
+            select = f"({select})"
+        return select
+
+    def render_in(self, c: Compiler, elem: In) -> str:
+        compile_fn = functools.partial(self.compile, c)
+        elems = ", ".join(map(compile_fn, elem.list))
+        return f"({self.compile(c, elem.expr)} IN ({elems}))"
+
+    def render_cast(self, c: Compiler, elem: Cast) -> str:
+        return f"cast({self.compile(c, elem.expr)} as {self.compile(c, elem.target_type)})"
+
+    def render_random(self, c: Compiler, elem: Random) -> str:
+        return self.random()
+
+    def render_explain(self, c: Compiler, elem: Explain) -> str:
+        return self.explain_as_text(self.compile(c, elem.select))
+
+    def render_currenttimestamp(self, c: Compiler, elem: CurrentTimestamp) -> str:
+        return self.current_timestamp()
+
+    def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str:
+        assert isinstance(c, AbstractMixin_TimeTravel)
+        return self.compile(
+            c,
+            # TODO: why is it c.? why not self? time-travelling is the dialect's thing, isn't it?
+            c.time_travel(
+                elem.table, before=elem.before, timestamp=elem.timestamp, offset=elem.offset, statement=elem.statement
+            )
+        )
+
+    def render_createtable(self, c: Compiler, elem: CreateTable) -> str:
+        ne = "IF NOT EXISTS " if elem.if_not_exists else ""
+        if elem.source_table:
+            return f"CREATE TABLE {ne}{self.compile(c, elem.path)} AS {self.compile(c, elem.source_table)}"
+
+        schema = ", ".join(f"{self.quote(k)} {self.type_repr(v)}" for k, v in elem.path.schema.items())
+        pks = (
+            ", PRIMARY KEY (%s)" % ", ".join(elem.primary_keys)
+            if elem.primary_keys and self.SUPPORTS_PRIMARY_KEY
+            else ""
+        )
+        return f"CREATE TABLE {ne}{self.compile(c, elem.path)}({schema}{pks})"
+
+    def render_droptable(self, c: Compiler, elem: DropTable) -> str:
+        ie = "IF EXISTS " if elem.if_exists else ""
+        return f"DROP TABLE {ie}{self.compile(c, elem.path)}"
+
+    def render_truncatetable(self, c: Compiler, elem: TruncateTable) -> str:
+        return f"TRUNCATE TABLE {self.compile(c, elem.path)}"
+
+    def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
+        if isinstance(elem.expr, ConstantTable):
+            expr = self.constant_values(elem.expr.rows)
+        else:
+            expr = self.compile(c, elem.expr)
+
+        columns = "(%s)" % ", ".join(map(self.quote, elem.columns)) if elem.columns is not None else ""
+
+        return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}"
+
     def offset_limit(
         self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
     ) -> str:
@@ -335,8 +767,7 @@ def name(self):
         return type(self).__name__
 
     def compile(self, sql_ast):
-        compiler = Compiler(self)
-        return compiler.compile(sql_ast)
+        return self.dialect.compile(Compiler(self), sql_ast)
 
     def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
         """Query the given SQL code/AST, and attempt to convert the result to type 'res_type'
@@ -359,14 +790,14 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
             else:
                 if res_type is None:
                     res_type = sql_ast.type
-                sql_code = compiler.compile(sql_ast)
+                sql_code = self.compile(sql_ast)
                 if sql_code is SKIP:
                     return SKIP
 
             logger.debug("Running SQL (%s): %s", self.name, sql_code)
 
         if self._interactive and isinstance(sql_ast, Select):
-            explained_sql = compiler.compile(Explain(sql_ast))
+            explained_sql = self.compile(Explain(sql_ast))
             explain = self._query(explained_sql)
             for row in explain:
                 # Most returned a 1-tuple. Presto returns a string
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index 667786a7..c40a2b99 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -58,15 +58,15 @@ def sample(table_expr):
 
 def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
     db = c.database
-    c = c.replace(root=False)  # we're compiling fragments, not full queries
+    c: Compiler = c.replace(root=False)  # we're compiling fragments, not full queries
     if isinstance(db, BigQuery):
-        return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
+        return f"create table {c.dialect.compile(c, path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.dialect.compile(c, expr)}"
     elif isinstance(db, Presto):
-        return f"create table {c.compile(path)} as {c.compile(expr)}"
+        return f"create table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}"
     elif isinstance(db, Oracle):
-        return f"create global temporary table {c.compile(path)} as {c.compile(expr)}"
+        return f"create global temporary table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}"
     else:
-        return f"create temporary table {c.compile(path)} as {c.compile(expr)}"
+        return f"create temporary table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}"
 
 
 def bool_to_int(x):
diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py
index 954783de..0013fef7 100644
--- a/data_diff/queries/ast_classes.py
+++ b/data_diff/queries/ast_classes.py
@@ -5,13 +5,12 @@
 from runtype import dataclass
 from typing_extensions import Self
 
-from data_diff.utils import join_iter, ArithString
+from data_diff.utils import ArithString
 from data_diff.abcs.compiler import Compilable
 from data_diff.abcs.database_types import AbstractTable
-from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel
 from data_diff.schema import Schema
 
-from data_diff.queries.compiler import Compiler, cv_params, Root, CompileError
+from data_diff.queries.compiler import Compiler
 from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError
 from data_diff.abcs.database_types import DbPath
 
@@ -24,6 +23,10 @@ class QB_TypeError(QueryBuilderError):
     pass
 
 
+class Root:
+    "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)"
+
+
 class ExprNode(Compilable):
     "Base class for query expression nodes"
 
@@ -54,13 +57,6 @@ class Code(ExprNode, Root):
     code: str
     args: Dict[str, Expr] = None
 
-    def compile(self, c: Compiler) -> str:
-        if not self.args:
-            return self.code
-
-        args = {k: c.compile(v) for k, v in self.args.items()}
-        return self.code.format(**args)
-
 
 def _expr_type(e: Expr) -> type:
     if isinstance(e, ExprNode):
@@ -73,9 +69,6 @@ class Alias(ExprNode):
     expr: Expr
     name: str
 
-    def compile(self, c: Compiler) -> str:
-        return f"{c.compile(self.expr)} AS {c.dialect.quote(self.name)}"
-
     @property
     def type(self):
         return _expr_type(self.expr)
@@ -178,32 +171,13 @@ class Concat(ExprNode):
     exprs: list
     sep: str = None
 
-    def compile(self, c: Compiler) -> str:
-        # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
-        items = [f"coalesce({c.compile(Code(c.dialect.to_string(c.compile(expr))))}, '<null>')" for expr in self.exprs]
-        assert items
-        if len(items) == 1:
-            return items[0]
-
-        if self.sep:
-            items = list(join_iter(f"'{self.sep}'", items))
-        return c.dialect.concat(items)
-
 
 @dataclass
 class Count(ExprNode):
     expr: Expr = None
     distinct: bool = False
-
     type = int
 
-    def compile(self, c: Compiler) -> str:
-        expr = c.compile(self.expr) if self.expr else "*"
-        if self.distinct:
-            return f"count(distinct {expr})"
-
-        return f"count({expr})"
-
 
 class LazyOps:
     def __add__(self, other):
@@ -262,43 +236,24 @@ class TestRegex(ExprNode, LazyOps):
     string: Expr
     pattern: Expr
 
-    def compile(self, c: Compiler) -> str:
-        if not isinstance(c.dialect, AbstractMixin_Regex):
-            raise NotImplementedError(f"No regex implementation for database '{c.database}'")
-        regex = c.dialect.test_regex(self.string, self.pattern)
-        return c.compile(regex)
-
 
 @dataclass(eq=False)
 class Func(ExprNode, LazyOps):
     name: str
     args: Sequence[Expr]
 
-    def compile(self, c: Compiler) -> str:
-        args = ", ".join(c.compile(e) for e in self.args)
-        return f"{self.name}({args})"
-
 
 @dataclass
 class WhenThen(ExprNode):
     when: Expr
     then: Expr
 
-    def compile(self, c: Compiler) -> str:
-        return f"WHEN {c.compile(self.when)} THEN {c.compile(self.then)}"
-
 
 @dataclass
 class CaseWhen(ExprNode):
     cases: Sequence[WhenThen]
     else_expr: Expr = None
 
-    def compile(self, c: Compiler) -> str:
-        assert self.cases
-        when_thens = " ".join(c.compile(case) for case in self.cases)
-        else_expr = (" ELSE " + c.compile(self.else_expr)) if self.else_expr is not None else ""
-        return f"CASE {when_thens}{else_expr} END"
-
     @property
     def type(self):
         then_types = {_expr_type(case.then) for case in self.cases}
@@ -353,21 +308,12 @@ class IsDistinctFrom(ExprNode, LazyOps):
     b: Expr
     type = bool
 
-    def compile(self, c: Compiler) -> str:
-        a = c.dialect.to_comparable(c.compile(self.a), self.a.type)
-        b = c.dialect.to_comparable(c.compile(self.b), self.b.type)
-        return c.dialect.is_distinct_from(a, b)
-
 
 @dataclass(eq=False, order=False)
 class BinOp(ExprNode, LazyOps):
     op: str
     args: Sequence[Expr]
 
-    def compile(self, c: Compiler) -> str:
-        expr = f" {self.op} ".join(c.compile(a) for a in self.args)
-        return f"({expr})"
-
     @property
     def type(self):
         types = {_expr_type(i) for i in self.args}
@@ -382,9 +328,6 @@ class UnaryOp(ExprNode, LazyOps):
     op: str
     expr: Expr
 
-    def compile(self, c: Compiler) -> str:
-        return f"({self.op}{c.compile(self.expr)})"
-
 
 class BinBoolOp(BinOp):
     type = bool
@@ -401,22 +344,6 @@ def type(self):
             raise QueryBuilderError(f"Schema required for table {self.source_table}")
         return self.source_table.schema[self.name]
 
-    def compile(self, c: Compiler) -> str:
-        if c._table_context:
-            if len(c._table_context) > 1:
-                aliases = [
-                    t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table
-                ]
-                if not aliases:
-                    return c.dialect.quote(self.name)
-                elif len(aliases) > 1:
-                    raise CompileError(f"Too many aliases for column {self.name}")
-                (alias,) = aliases
-
-                return f"{c.dialect.quote(alias.name)}.{c.dialect.quote(self.name)}"
-
-        return c.dialect.quote(self.name)
-
 
 @dataclass
 class TablePath(ExprNode, ITable):
@@ -427,10 +354,6 @@ class TablePath(ExprNode, ITable):
     def source_table(self) -> Self:
         return self
 
-    def compile(self, c: Compiler) -> str:
-        path = self.path  # c.database._normalize_table_path(self.name)
-        return ".".join(map(c.dialect.quote, path))
-
     # Statement shorthands
     def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None):
         """Returns a query expression to create a new table.
@@ -514,9 +437,6 @@ class TableAlias(ExprNode, ITable):
     source_table: ITable
     name: str
 
-    def compile(self, c: Compiler) -> str:
-        return f"{c.compile(self.source_table)} {c.dialect.quote(self.name)}"
-
 
 @dataclass
 class Join(ExprNode, ITable, Root):
@@ -564,29 +484,6 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]:
         # TODO Ensure exprs <= self.columns ?
         return self.replace(columns=exprs)
 
-    def compile(self, parent_c: Compiler) -> str:
-        tables = [
-            t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables
-        ]
-        c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
-        op = " JOIN " if self.op is None else f" {self.op} JOIN "
-        joined = op.join(c.compile(t) for t in tables)
-
-        if self.on_exprs:
-            on = " AND ".join(c.compile(e) for e in self.on_exprs)
-            res = f"{joined} ON {on}"
-        else:
-            res = joined
-
-        columns = "*" if self.columns is None else ", ".join(map(c.compile, self.columns))
-        select = f"SELECT {columns} FROM {res}"
-
-        if parent_c.in_select:
-            select = f"({select}) {c.new_unique_name()}"
-        elif parent_c.in_join:
-            select = f"({select})"
-        return select
-
 
 @dataclass
 class GroupBy(ExprNode, ITable, Root):
@@ -619,36 +516,6 @@ def agg(self, *exprs) -> Self:
         resolve_names(self.table, exprs)
         return self.replace(values=(self.values or []) + exprs)
 
-    def compile(self, c: Compiler) -> str:
-        if self.values is None:
-            raise CompileError(".group_by() must be followed by a call to .agg()")
-
-        keys = [str(i + 1) for i in range(len(self.keys))]
-        columns = (self.keys or []) + (self.values or [])
-        if isinstance(self.table, Select) and self.table.columns is None and self.table.group_by_exprs is None:
-            return c.compile(
-                self.table.replace(
-                    columns=columns,
-                    group_by_exprs=[Code(k) for k in keys],
-                    having_exprs=self.having_exprs,
-                )
-            )
-
-        keys_str = ", ".join(keys)
-        columns_str = ", ".join(c.compile(x) for x in columns)
-        having_str = (
-            " HAVING " + " AND ".join(map(c.compile, self.having_exprs)) if self.having_exprs is not None else ""
-        )
-        select = (
-            f"SELECT {columns_str} FROM {c.replace(in_select=True).compile(self.table)} GROUP BY {keys_str}{having_str}"
-        )
-
-        if c.in_select:
-            select = f"({select}) {c.new_unique_name()}"
-        elif c.in_join:
-            select = f"({select})"
-        return select
-
 
 @dataclass
 class TableOp(ExprNode, ITable, Root):
@@ -672,15 +539,6 @@ def schema(self):
         assert len(s1) == len(s2)
         return s1
 
-    def compile(self, parent_c: Compiler) -> str:
-        c = parent_c.replace(in_select=False)
-        table_expr = f"{c.compile(self.table1)} {self.op} {c.compile(self.table2)}"
-        if parent_c.in_select:
-            table_expr = f"({table_expr}) {c.new_unique_name()}"
-        elif parent_c.in_join:
-            table_expr = f"({table_expr})"
-        return table_expr
-
 
 @dataclass
 class Select(ExprNode, ITable, Root):
@@ -705,42 +563,6 @@ def schema(self):
     def source_table(self):
         return self
 
-    def compile(self, parent_c: Compiler) -> str:
-        c = parent_c.replace(in_select=True)  # .add_table_context(self.table)
-
-        columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*"
-        distinct = "DISTINCT " if self.distinct else ""
-        optimizer_hints = c.dialect.optimizer_hints(self.optimizer_hints) if self.optimizer_hints else ""
-        select = f"SELECT {optimizer_hints}{distinct}{columns}"
-
-        if self.table:
-            select += " FROM " + c.compile(self.table)
-        elif c.dialect.PLACEHOLDER_TABLE:
-            select += f" FROM {c.dialect.PLACEHOLDER_TABLE}"
-
-        if self.where_exprs:
-            select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs))
-
-        if self.group_by_exprs:
-            select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs))
-
-        if self.having_exprs:
-            assert self.group_by_exprs
-            select += " HAVING " + " AND ".join(map(c.compile, self.having_exprs))
-
-        if self.order_by_exprs:
-            select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs))
-
-        if self.limit_expr is not None:
-            has_order_by = bool(self.order_by_exprs)
-            select += " " + c.dialect.offset_limit(0, self.limit_expr, has_order_by=has_order_by)
-
-        if parent_c.in_select:
-            select = f"({select}) {c.new_unique_name()}"
-        elif parent_c.in_join:
-            select = f"({select})"
-        return select
-
     @classmethod
     def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs):
         assert "table" not in kwargs
@@ -783,16 +605,6 @@ class Cte(ExprNode, ITable):
     name: str = None
     params: Sequence[str] = None
 
-    def compile(self, parent_c: Compiler) -> str:
-        c = parent_c.replace(_table_context=[], in_select=False)
-        compiled = c.compile(self.source_table)
-
-        name = self.name or parent_c.new_unique_name()
-        name_params = f"{name}({', '.join(self.params)})" if self.params else name
-        parent_c._subqueries[name_params] = compiled
-
-        return name
-
     @property
     def schema(self):
         # TODO add cte to schema
@@ -829,9 +641,6 @@ def _get_resolved(self) -> Expr:
             raise QueryBuilderError(f"Column not resolved: {self.resolve_name}")
         return self.resolved
 
-    def compile(self, c: Compiler) -> str:
-        return self._get_resolved().compile(c)
-
     @property
     def type(self):
         return self._get_resolved().type
@@ -860,58 +669,34 @@ def __getitem__(self, name):
 class In(ExprNode):
     expr: Expr
     list: Sequence[Expr]
-
     type = bool
 
-    def compile(self, c: Compiler):
-        elems = ", ".join(map(c.compile, self.list))
-        return f"({c.compile(self.expr)} IN ({elems}))"
-
 
 @dataclass
 class Cast(ExprNode):
     expr: Expr
     target_type: Expr
 
-    def compile(self, c: Compiler) -> str:
-        return f"cast({c.compile(self.expr)} as {c.compile(self.target_type)})"
-
 
 @dataclass
 class Random(ExprNode, LazyOps):
     type = float
 
-    def compile(self, c: Compiler) -> str:
-        return c.dialect.random()
-
 
 @dataclass
 class ConstantTable(ExprNode):
     rows: Sequence[Sequence]
 
-    def compile(self, c: Compiler) -> str:
-        raise NotImplementedError()
-
-    def compile_for_insert(self, c: Compiler):
-        return c.dialect.constant_values(self.rows)
-
 
 @dataclass
 class Explain(ExprNode, Root):
     select: Select
-
     type = str
 
-    def compile(self, c: Compiler) -> str:
-        return c.dialect.explain_as_text(c.compile(self.select))
-
 
 class CurrentTimestamp(ExprNode):
     type = datetime
 
-    def compile(self, c: Compiler) -> str:
-        return c.dialect.current_timestamp()
-
 
 @dataclass
 class TimeTravel(ITable):
@@ -921,14 +706,6 @@ class TimeTravel(ITable):
     offset: int = None
     statement: str = None
 
-    def compile(self, c: Compiler) -> str:
-        assert isinstance(c, AbstractMixin_TimeTravel)
-        return c.compile(
-            c.time_travel(
-                self.table, before=self.before, timestamp=self.timestamp, offset=self.offset, statement=self.statement
-            )
-        )
-
 
 # DDL
 
@@ -944,37 +721,17 @@ class CreateTable(Statement):
     if_not_exists: bool = False
     primary_keys: List[str] = None
 
-    def compile(self, c: Compiler) -> str:
-        ne = "IF NOT EXISTS " if self.if_not_exists else ""
-        if self.source_table:
-            return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
-
-        schema = ", ".join(f"{c.dialect.quote(k)} {c.dialect.type_repr(v)}" for k, v in self.path.schema.items())
-        pks = (
-            ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys)
-            if self.primary_keys and c.dialect.SUPPORTS_PRIMARY_KEY
-            else ""
-        )
-        return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})"
-
 
 @dataclass
 class DropTable(Statement):
     path: TablePath
     if_exists: bool = False
 
-    def compile(self, c: Compiler) -> str:
-        ie = "IF EXISTS " if self.if_exists else ""
-        return f"DROP TABLE {ie}{c.compile(self.path)}"
-
 
 @dataclass
 class TruncateTable(Statement):
     path: TablePath
 
-    def compile(self, c: Compiler) -> str:
-        return f"TRUNCATE TABLE {c.compile(self.path)}"
-
 
 @dataclass
 class InsertToTable(Statement):
@@ -983,16 +740,6 @@ class InsertToTable(Statement):
     columns: List[str] = None
     returning_exprs: List[str] = None
 
-    def compile(self, c: Compiler) -> str:
-        if isinstance(self.expr, ConstantTable):
-            expr = self.expr.compile_for_insert(c)
-        else:
-            expr = c.compile(self.expr)
-
-        columns = "(%s)" % ", ".join(map(c.dialect.quote, self.columns)) if self.columns is not None else ""
-
-        return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
-
     def returning(self, *exprs) -> Self:
         """Add a 'RETURNING' clause to the insert expression.
 
@@ -1014,20 +761,12 @@ def returning(self, *exprs) -> Self:
 class Commit(Statement):
     """Generate a COMMIT statement, if we're in the middle of a transaction, or in auto-commit. Otherwise SKIP."""
 
-    def compile(self, c: Compiler) -> str:
-        return "COMMIT" if not c.database.is_autocommit else SKIP
-
 
 @dataclass
 class Param(ExprNode, ITable):
     """A value placeholder, to be specified at compilation time using the `cv_params` context variable."""
-
     name: str
 
     @property
     def source_table(self):
         return self
-
-    def compile(self, c: Compiler) -> str:
-        params = cv_params.get()
-        return c._compile(params[self.name])
diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py
index b0df04eb..224ad636 100644
--- a/data_diff/queries/compiler.py
+++ b/data_diff/queries/compiler.py
@@ -1,30 +1,30 @@
 import random
 from dataclasses import field
-from datetime import datetime
 from typing import Any, Dict, Sequence, List
 
 from runtype import dataclass
 from typing_extensions import Self
 
-from data_diff.utils import ArithString
 from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect, DbPath
-from data_diff.abcs.compiler import AbstractCompiler, Compilable
-
-import contextvars
-
-cv_params = contextvars.ContextVar("params")
+from data_diff.abcs.compiler import AbstractCompiler
 
 
 class CompileError(Exception):
     pass
 
 
-class Root:
-    "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)"
-
-
 @dataclass
 class Compiler(AbstractCompiler):
+    """
+    Compiler bears the context for a single compilation.
+
+    There can be multiple compilations per app run.
+    There can be multiple compilers in one compilation (with varying contexts).
+    """
+
+    # Database is needed to normalize tables. Dialect is needed for recursive compilations.
+    # In theory, this is a many-to-many relation: e.g. a generic ODBC driver with multiple dialects.
+    # In practice, we currently bind the dialects to the specific database classes.
     database: AbstractDatabase
 
     in_select: bool = False  # Compilation runtime flag
@@ -40,38 +40,9 @@ class Compiler(AbstractCompiler):
     def dialect(self) -> AbstractDialect:
         return self.database.dialect
 
+    # TODO: DEPRECATED: Remove once the dialect is used directly in all places.
     def compile(self, elem, params=None) -> str:
-        if params:
-            cv_params.set(params)
-
-        if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
-            from data_diff.queries.ast_classes import Select
-
-            elem = Select(columns=[elem])
-
-        res = self._compile(elem)
-        if self.root and self._subqueries:
-            subq = ", ".join(f"\n  {k} AS ({v})" for k, v in self._subqueries.items())
-            self._subqueries.clear()
-            return f"WITH {subq}\n{res}"
-        return res
-
-    def _compile(self, elem) -> str:
-        if elem is None:
-            return "NULL"
-        elif isinstance(elem, Compilable):
-            return elem.compile(self.replace(root=False))
-        elif isinstance(elem, str):
-            return f"'{elem}'"
-        elif isinstance(elem, (int, float)):
-            return str(elem)
-        elif isinstance(elem, datetime):
-            return self.dialect.timestamp_value(elem)
-        elif isinstance(elem, bytes):
-            return f"b'{elem.decode()}'"
-        elif isinstance(elem, ArithString):
-            return f"'{elem}'"
-        assert False, elem
+        return self.dialect.compile(self, elem, params)
 
     def new_unique_name(self, prefix="tmp"):
         self._counter[0] += 1
diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py
index 8e916601..bb0c8299 100644
--- a/data_diff/queries/extras.py
+++ b/data_diff/queries/extras.py
@@ -1,12 +1,10 @@
 "Useful AST classes that don't quite fall within the scope of regular SQL"
-
 from typing import Callable, Sequence
 from runtype import dataclass
 
-from data_diff.abcs.database_types import ColType, Native_UUID
+from data_diff.abcs.database_types import ColType
 
-from data_diff.queries.compiler import Compiler
-from data_diff.queries.ast_classes import Expr, ExprNode, Concat, Code
+from data_diff.queries.ast_classes import Expr, ExprNode
 
 
 @dataclass
@@ -15,48 +13,13 @@ class NormalizeAsString(ExprNode):
     expr_type: ColType = None
     type = str
 
-    def compile(self, c: Compiler) -> str:
-        expr = c.compile(self.expr)
-        return c.dialect.normalize_value_by_type(expr, self.expr_type or self.expr.type)
-
 
 @dataclass
 class ApplyFuncAndNormalizeAsString(ExprNode):
     expr: ExprNode
     apply_func: Callable = None
 
-    def compile(self, c: Compiler) -> str:
-        expr = self.expr
-        expr_type = expr.type
-
-        if isinstance(expr_type, Native_UUID):
-            # Normalize first, apply template after (for uuids)
-            # Needed because min/max(uuid) fails in postgresql
-            expr = NormalizeAsString(expr, expr_type)
-            if self.apply_func is not None:
-                expr = self.apply_func(expr)  # Apply template using Python's string formatting
-
-        else:
-            # Apply template before normalizing (for ints)
-            if self.apply_func is not None:
-                expr = self.apply_func(expr)  # Apply template using Python's string formatting
-            expr = NormalizeAsString(expr, expr_type)
-
-        return c.compile(expr)
-
 
 @dataclass
 class Checksum(ExprNode):
     exprs: Sequence[Expr]
-
-    def compile(self, c: Compiler):
-        if len(self.exprs) > 1:
-            exprs = [Code(f"coalesce({c.compile(expr)}, '<null>')") for expr in self.exprs]
-            # exprs = [c.compile(e) for e in exprs]
-            expr = Concat(exprs, "|")
-        else:
-            # No need to coalesce - safe to assume that key cannot be null
-            (expr,) = self.exprs
-        expr = c.compile(expr)
-        md5 = c.dialect.md5_as_int(expr)
-        return f"sum({md5})"