1
1
from dataclasses import field
2
2
from datetime import datetime
3
- from typing import Any , Generator , List , Optional , Sequence , Union , Dict
3
+ from typing import Any , Generator , List , Optional , Sequence , Type , Union , Dict
4
4
5
5
from runtype import dataclass
6
+ from typing_extensions import Self
6
7
7
8
from ..utils import join_iter , ArithString
8
9
from ..abcs import Compilable
@@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When":
322
323
return QB_When (self , whens [0 ])
323
324
return QB_When (self , BinBoolOp ("AND" , whens ))
324
325
325
- def else_ (self , then : Expr ):
326
+ def else_ (self , then : Expr ) -> Self :
326
327
"""Add an 'else' clause to the case expression.
327
328
328
329
Can only be called once!
@@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable):
422
423
schema : Optional [Schema ] = field (default = None , repr = False )
423
424
424
425
@property
425
- def source_table (self ):
426
+ def source_table (self ) -> Self :
426
427
return self
427
428
428
429
def compile (self , c : Compiler ) -> str :
@@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root):
524
525
columns : Sequence [Expr ] = None
525
526
526
527
@property
527
- def source_table (self ):
528
+ def source_table (self ) -> Self :
528
529
return self
529
530
530
531
@property
@@ -533,7 +534,7 @@ def schema(self):
533
534
s = self .source_tables [0 ].schema # TODO validate types match between both tables
534
535
return type (s )({c .name : c .type for c in self .columns })
535
536
536
- def on (self , * exprs ) -> "Join" :
537
+ def on (self , * exprs ) -> Self :
537
538
"""Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
538
539
if len (exprs ) == 1 :
539
540
(e ,) = exprs
@@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join":
546
547
547
548
return self .replace (on_exprs = (self .on_exprs or []) + exprs )
548
549
549
- def select (self , * exprs , ** named_exprs ) -> ITable :
550
+ def select (self , * exprs , ** named_exprs ) -> Union [ Self , ITable ] :
550
551
"""Select fields to return from the JOIN operation
551
552
552
553
See Also: ``ITable.select()``
@@ -600,7 +601,7 @@ def source_table(self):
600
601
def __post_init__ (self ):
601
602
assert self .keys or self .values
602
603
603
- def having (self , * exprs ):
604
+ def having (self , * exprs ) -> Self :
604
605
"""Add a 'HAVING' clause to the group-by"""
605
606
exprs = args_as_tuple (exprs )
606
607
exprs = _drop_skips (exprs )
@@ -610,7 +611,7 @@ def having(self, *exprs):
610
611
resolve_names (self .table , exprs )
611
612
return self .replace (having_exprs = (self .having_exprs or []) + exprs )
612
613
613
- def agg (self , * exprs ):
614
+ def agg (self , * exprs ) -> Self :
614
615
"""Select aggregated fields for the group-by."""
615
616
exprs = args_as_tuple (exprs )
616
617
exprs = _drop_skips (exprs )
@@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str:
991
992
992
993
return f"INSERT INTO { c .compile (self .path )} { columns } { expr } "
993
994
994
- def returning (self , * exprs ):
995
+ def returning (self , * exprs ) -> Self :
995
996
"""Add a 'RETURNING' clause to the insert expression.
996
997
997
998
Note: Not all databases support this feature!
0 commit comments