1
1
import abc
2
+ import collections
2
3
import functools
3
4
import json
4
5
import logging
12
13
from sqlalchemy .sql .compiler import SQLCompiler
13
14
from sqlalchemy .sql .expression import ClauseElement , Executable
14
15
15
- from subsetter .common import DatabaseConfig , parse_table_name
16
+ from subsetter .common import DatabaseConfig , parse_table_name , pydantic_search
16
17
from subsetter .config_model import (
17
18
ConflictStrategy ,
18
19
DatabaseOutputConfig ,
21
22
)
22
23
from subsetter .filters import FilterOmit , FilterView , FilterViewChain
23
24
from subsetter .metadata import DatabaseMetadata
24
- from subsetter .plan_model import SQLTableIdentifier
25
+ from subsetter .plan_model import SQLLeftJoin , SQLTableIdentifier
25
26
from subsetter .planner import SubsetPlan
26
27
from subsetter .solver import toposort
27
28
@@ -69,7 +70,7 @@ def create(
69
70
select : sa .Select ,
70
71
* ,
71
72
name : str = "" ,
72
- primary_key : Tuple [str , ...] = (),
73
+ indexes : Iterable [ Tuple [str , ...] ] = (),
73
74
) -> Tuple [sa .Table , int ]:
74
75
"""
75
76
Create a temporary table on the passed connection generated by the passed
@@ -82,9 +83,8 @@ def create(
82
83
schema: The schema to create the temporary table within. For some dialects
83
84
temporary tables always exist in their own schema and this parameter
84
85
will be ignored.
85
- primary_key: If set will mark the set of columns passed as primary keys in
86
- the temporary table. This tuple should match a subset of the
87
- column names in the select query.
86
+ indexes: creates an index on each tuple of columns listed. This is useful
87
+ if future queries are likely to reference these columns.
88
88
89
89
Returns a tuple containing the generated table object and the number of rows that
90
90
were inserted in the table.
@@ -106,10 +106,7 @@ def create(
106
106
metadata ,
107
107
schema = temp_schema ,
108
108
prefixes = ["TEMPORARY" ],
109
- * (
110
- sa .Column (col .name , col .type , primary_key = col .name in primary_key )
111
- for col in select .selected_columns
112
- ),
109
+ * (sa .Column (col .name , col .type ) for col in select .selected_columns ),
113
110
)
114
111
try :
115
112
metadata .create_all (conn )
@@ -122,6 +119,22 @@ def create(
122
119
if "--read-only" not in str (exc ):
123
120
raise
124
121
122
+ for idx , index_cols in enumerate (indexes ):
123
+ # For some dialects/data types we may not be able to construct an index. We just do our
124
+ # best here instead of hard failing.
125
+ try :
126
+ sa .Index (
127
+ f"{ temp_name } _idx_{ idx } " ,
128
+ * (table_obj .columns [col_name ] for col_name in index_cols ),
129
+ ).create (bind = conn )
130
+ except sa .exc .OperationalError :
131
+ LOGGER .warning (
132
+ "Failed to create index %s on temporary table %s" ,
133
+ index_cols ,
134
+ temp_name ,
135
+ exc_info = True ,
136
+ )
137
+
125
138
# Copy data into the temporary table
126
139
stmt = table_obj .insert ().from_select (list (table_obj .columns ), select )
127
140
LOGGER .debug (
@@ -834,6 +847,18 @@ def _materialize_tables(
834
847
conn : sa .Connection ,
835
848
plan : SubsetPlan ,
836
849
) -> None :
850
+ # Figure out what sets of columns are going to be queried for our materialized tables.
851
+ joined_columns = collections .defaultdict (set )
852
+ for data in pydantic_search (plan ):
853
+ if not isinstance (data , SQLLeftJoin ):
854
+ continue
855
+ table_id = data .right
856
+ if not table_id .sampled :
857
+ continue
858
+ joined_columns [(table_id .table_schema , table_id .table_name )].add (
859
+ tuple (data .right_columns )
860
+ )
861
+
837
862
materialization_order = self ._materialization_order (meta , plan )
838
863
for schema , table_name , ref_count in materialization_order :
839
864
table = meta .tables [(schema , table_name )]
@@ -866,7 +891,7 @@ def _materialize_tables(
866
891
schema ,
867
892
table_q ,
868
893
name = table_name ,
869
- primary_key = table . primary_key ,
894
+ indexes = joined_columns [( schema , table_name )] ,
870
895
)
871
896
)
872
897
self .cached_table_sizes [(schema , table_name )] = rowcount
@@ -889,7 +914,7 @@ def _materialize_tables(
889
914
schema ,
890
915
meta .temp_tables [(schema , table_name , 0 )].select (),
891
916
name = table_name ,
892
- primary_key = table . primary_key ,
917
+ indexes = joined_columns [( schema , table_name )] ,
893
918
)
894
919
)
895
920
LOGGER .info (
0 commit comments