
Commit 2fd78db

Create indexes as needed
1 parent 137cc60 commit 2fd78db

2 files changed: +70 -13 lines


subsetter/common.py

Lines changed: 33 additions & 1 deletion
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
 
 import sqlalchemy as sa
 from pydantic import BaseModel
@@ -217,3 +217,35 @@ def _set_session_sqls(dbapi_connection, _):
     cursor.close()
 
     return engine
+
+
+def pydantic_search(root: Any) -> Iterable[BaseModel]:
+    """
+    A generator that yields all sub-models found underneath the passed root object (including the
+    root object itself). Searches model fields as well as through lists and dicts found in those
+    fields.
+    """
+    vis = set()
+    stack = []
+
+    def _push(key: Any, value: Any):
+        if isinstance(value, (BaseModel, list, dict)):
+            if id(value) not in vis:
+                vis.add(id(value))
+                stack.append(value)
+
+    _push(None, root)
+    while stack:
+        data = stack.pop()
+        if isinstance(data, BaseModel):
+            yield data
+            for field, _ in data.model_fields.items():
+                _push(field, getattr(data, field))
+
+        if isinstance(data, list):
+            for idx, elem in enumerate(data):
+                _push(idx, elem)
+
+        if isinstance(data, dict):
+            for key, elem in data.items():
+                _push(key, elem)
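
A hedged sketch of how pydantic_search can be used; the Plan and Step models below are hypothetical stand-ins for illustration, not types from the subsetter codebase:

from typing import Dict, List

from pydantic import BaseModel

from subsetter.common import pydantic_search


class Step(BaseModel):
    name: str


class Plan(BaseModel):
    steps: List[Step]
    extras: Dict[str, Step]


plan = Plan(steps=[Step(name="a")], extras={"x": Step(name="b")})

# Yields the Plan itself plus every Step reachable through its list and dict fields.
for model in pydantic_search(plan):
    print(type(model).__name__)  # Plan, Step, Step (traversal order may vary)

Because visited models and containers are tracked by id(), shared or cyclic references are traversed only once.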

subsetter/sampler.py

Lines changed: 37 additions & 12 deletions
@@ -1,4 +1,5 @@
 import abc
+import collections
 import functools
 import json
 import logging
@@ -12,7 +13,7 @@
 from sqlalchemy.sql.compiler import SQLCompiler
 from sqlalchemy.sql.expression import ClauseElement, Executable
 
-from subsetter.common import DatabaseConfig, parse_table_name
+from subsetter.common import DatabaseConfig, parse_table_name, pydantic_search
 from subsetter.config_model import (
     ConflictStrategy,
     DatabaseOutputConfig,
@@ -21,7 +22,7 @@
 )
 from subsetter.filters import FilterOmit, FilterView, FilterViewChain
 from subsetter.metadata import DatabaseMetadata
-from subsetter.plan_model import SQLTableIdentifier
+from subsetter.plan_model import SQLLeftJoin, SQLTableIdentifier
 from subsetter.planner import SubsetPlan
 from subsetter.solver import toposort
 
@@ -69,7 +70,7 @@ def create(
         select: sa.Select,
         *,
         name: str = "",
-        primary_key: Tuple[str, ...] = (),
+        indexes: Iterable[Tuple[str, ...]] = (),
     ) -> Tuple[sa.Table, int]:
         """
         Create a temporary table on the passed connection generated by the passed
@@ -82,9 +83,8 @@ def create(
             schema: The schema to create the temporary table within. For some dialects
                 temporary tables always exist in their own schema and this parameter
                 will be ignored.
-            primary_key: If set will mark the set of columns passed as primary keys in
-                the temporary table. This tuple should match a subset of the
-                column names in the select query.
+            indexes: creates an index on each tuple of columns listed. This is useful
+                if future queries are likely to reference these columns.
 
         Returns a tuple containing the generated table object and the number of rows that
         were inserted in the table.
@@ -106,10 +106,7 @@ def create(
             metadata,
             schema=temp_schema,
             prefixes=["TEMPORARY"],
-            *(
-                sa.Column(col.name, col.type, primary_key=col.name in primary_key)
-                for col in select.selected_columns
-            ),
+            *(sa.Column(col.name, col.type) for col in select.selected_columns),
         )
         try:
             metadata.create_all(conn)
@@ -122,6 +119,22 @@ def create(
             if "--read-only" not in str(exc):
                 raise
 
+        for idx, index_cols in enumerate(indexes):
+            # For some dialects/data types we may not be able to construct an index. We just do our
+            # best here instead of hard failing.
+            try:
+                sa.Index(
+                    f"{temp_name}_idx_{idx}",
+                    *(table_obj.columns[col_name] for col_name in index_cols),
+                ).create(bind=conn)
+            except sa.exc.OperationalError:
+                LOGGER.warning(
+                    "Failed to create index %s on temporary table %s",
+                    index_cols,
+                    temp_name,
+                    exc_info=True,
+                )
+
         # Copy data into the temporary table
         stmt = table_obj.insert().from_select(list(table_obj.columns), select)
         LOGGER.debug(
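
The index-creation loop above is deliberately best-effort: some dialects or column types reject certain indexes, so failures are logged and skipped rather than aborting the copy. A minimal self-contained sketch of the same pattern, assuming an in-memory SQLite database and illustrative table and column names:

import logging

import sqlalchemy as sa

logging.basicConfig()
LOGGER = logging.getLogger(__name__)

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
events = sa.Table(
    "events",
    metadata,
    sa.Column("user_id", sa.Integer),
    sa.Column("created_at", sa.DateTime),
)

with engine.begin() as conn:
    metadata.create_all(conn)
    # Try each candidate column tuple; keep going if the dialect refuses one.
    for idx, cols in enumerate([("user_id",), ("user_id", "created_at")]):
        try:
            sa.Index(
                f"events_idx_{idx}",
                *(events.columns[c] for c in cols),
            ).create(bind=conn)
        except sa.exc.OperationalError:
            LOGGER.warning("Failed to create index %s", cols, exc_info=True)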
@@ -834,6 +847,18 @@ def _materialize_tables(
         conn: sa.Connection,
         plan: SubsetPlan,
     ) -> None:
+        # Figure out what sets of columns are going to be queried for our materialized tables.
+        joined_columns = collections.defaultdict(set)
+        for data in pydantic_search(plan):
+            if not isinstance(data, SQLLeftJoin):
+                continue
+            table_id = data.right
+            if not table_id.sampled:
+                continue
+            joined_columns[(table_id.table_schema, table_id.table_name)].add(
+                tuple(data.right_columns)
+            )
+
         materialization_order = self._materialization_order(meta, plan)
         for schema, table_name, ref_count in materialization_order:
             table = meta.tables[(schema, table_name)]
@@ -866,7 +891,7 @@ def _materialize_tables(
                         schema,
                         table_q,
                         name=table_name,
-                        primary_key=table.primary_key,
+                        indexes=joined_columns[(schema, table_name)],
                     )
                 )
                 self.cached_table_sizes[(schema, table_name)] = rowcount
@@ -889,7 +914,7 @@ def _materialize_tables(
                         schema,
                         meta.temp_tables[(schema, table_name, 0)].select(),
                         name=table_name,
-                        primary_key=table.primary_key,
+                        indexes=joined_columns[(schema, table_name)],
                     )
                 )
         LOGGER.info(
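
To make the joined_columns bookkeeping concrete, here is a simplified, runnable sketch; SQLTableIdentifier and SQLLeftJoin below are stripped-down stand-ins for subsetter's actual plan models, keeping only the fields the loop above reads:

import collections
from typing import List

from pydantic import BaseModel


class SQLTableIdentifier(BaseModel):
    table_schema: str
    table_name: str
    sampled: bool = False


class SQLLeftJoin(BaseModel):
    right: SQLTableIdentifier
    right_columns: List[str]


joins = [
    SQLLeftJoin(
        right=SQLTableIdentifier(table_schema="app", table_name="users", sampled=True),
        right_columns=["id"],
    ),
    SQLLeftJoin(
        right=SQLTableIdentifier(table_schema="app", table_name="users", sampled=True),
        right_columns=["org_id", "id"],
    ),
]

# Group each joined column tuple under its (schema, table) key; the resulting
# sets are what gets passed to create() as the indexes argument.
joined_columns = collections.defaultdict(set)
for join in joins:
    if join.right.sampled:
        joined_columns[(join.right.table_schema, join.right.table_name)].add(
            tuple(join.right_columns)
        )

print(dict(joined_columns))
# {('app', 'users'): {('id',), ('org_id', 'id')}}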
