Skip to content

Commit 137cc60

Browse files
committed
change fk query strategy to joins when possible
1 parent 40d3c57 commit 137cc60

File tree

12 files changed

+501
-362
lines changed

12 files changed

+501
-362
lines changed

subsetter/plan_model.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,21 @@ def simplify(self) -> "SQLWhereClause":
232232
]
233233

234234

235+
class SQLLeftJoin(BaseModel):
    """A join from the table being selected to another table on paired columns.

    ``left_columns[i]`` on the selected (left) table is matched against
    ``right_columns[i]`` on ``right``.
    """

    # Table on the right-hand side of the join.
    right: SQLTableIdentifier
    # Columns of the left table, positionally paired with right_columns.
    left_columns: List[str]
    # Columns of ``right``, positionally paired with left_columns.
    right_columns: List[str]
    # When true the statement builder emits a real JOIN for this entry;
    # when false it emits an EXISTS sub-query instead.
    half_unique: bool = True

    def model_post_init(self, _context) -> None:
        """Reject column lists of different lengths.

        The statement builder pairs the two lists with ``zip``, which would
        silently drop the unmatched tail and yield a weaker join condition.

        Raises:
            ValueError: if the two column lists differ in length.
        """
        if len(self.left_columns) != len(self.right_columns):
            raise ValueError(
                "left_columns and right_columns must have the same length "
                f"(got {len(self.left_columns)} and {len(self.right_columns)})"
            )
240+
241+
235242
class SQLStatementSelect(BaseModel):
236243
type_: Literal["select"] = Field(..., alias="type")
237244
columns: Optional[List[str]] = None
238245
from_: SQLTableIdentifier = Field(..., alias="from")
239246
where: Optional[SQLWhereClause] = None
240247
limit: Optional[int] = None
248+
joins: Optional[List[SQLLeftJoin]] = None
249+
joins_outer: bool = False
241250

242251
model_config = ConfigDict(populate_by_name=True)
243252

@@ -250,6 +259,48 @@ def build(self, context: SQLBuildContext):
250259
else:
251260
stmt = sa.select(table_obj)
252261

262+
if self.joins:
263+
joined_cols: List[sa.ColumnExpression] = []
264+
joined: sa.FromClause = table_obj
265+
exists_constraints: List[sa.ColumnExpressionArgument] = []
266+
for join in self.joins: # pylint: disable=not-an-iterable
267+
right = join.right.build(context).alias()
268+
269+
if join.half_unique:
270+
joined = joined.join(
271+
right,
272+
onclause=sa.and_(
273+
*(
274+
table_obj.c[lft_col] == right.c[rht_col]
275+
for lft_col, rht_col in zip(
276+
join.left_columns, join.right_columns
277+
)
278+
)
279+
),
280+
isouter=self.joins_outer,
281+
)
282+
joined_cols.extend(
283+
right.c[rht_col] for rht_col in join.right_columns
284+
)
285+
else:
286+
exists_constraints.append(
287+
sa.exists().where(
288+
*(
289+
table_obj.c[lft_col] == right.c[rht_col]
290+
for lft_col, rht_col in zip(
291+
join.left_columns, join.right_columns
292+
)
293+
)
294+
)
295+
)
296+
297+
stmt = stmt.select_from(joined).distinct()
298+
if self.joins_outer:
299+
exists_constraints.extend(col.is_not(None) for col in joined_cols)
300+
stmt = stmt.where(sa.or_(*exists_constraints))
301+
elif exists_constraints:
302+
stmt = stmt.where(sa.and_(*exists_constraints))
303+
253304
if self.where:
254305
stmt = stmt.where(self.where.build(context, table_obj))
255306

@@ -273,6 +324,10 @@ def simplify(self) -> "SQLStatementSelect":
273324
kwargs["columns"] = self.columns
274325
if self.limit is not None:
275326
kwargs["limit"] = self.limit
327+
if self.joins:
328+
kwargs["joins"] = self.joins
329+
kwargs["joins_outer"] = self.joins_outer
330+
276331
return SQLStatementSelect(**kwargs) # type: ignore
277332

278333

subsetter/planner.py

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import logging
2-
from typing import Dict, List, Optional, Set, Tuple
2+
from typing import Dict, Iterable, List, Optional, Set, Tuple
3+
4+
import sqlalchemy as sa
35

46
from subsetter.common import DatabaseConfig, parse_table_name
57
from subsetter.config_model import PlannerConfig
68
from subsetter.metadata import DatabaseMetadata, ForeignKey, TableMetadata
79
from subsetter.plan_model import (
10+
SQLLeftJoin,
811
SQLStatementSelect,
912
SQLStatementUnion,
1013
SQLTableIdentifier,
@@ -13,7 +16,6 @@
1316
SQLWhereClauseAnd,
1417
SQLWhereClauseIn,
1518
SQLWhereClauseOperator,
16-
SQLWhereClauseOr,
1719
SQLWhereClauseRandom,
1820
SQLWhereClauseSQL,
1921
SubsetPlan,
@@ -267,8 +269,6 @@ def _plan_table(
267269
processed: Set[Tuple[str, str]],
268270
target: Optional[PlannerConfig.TargetConfig] = None,
269271
) -> SQLTableQuery:
270-
fk_constraints: List[SQLWhereClause] = []
271-
272272
foreign_keys = sorted(
273273
fk
274274
for fk in table.foreign_keys
@@ -311,33 +311,34 @@ def _plan_table(
311311
[f"{fk.dst_schema}.{fk.dst_table}" for fk in rev_foreign_keys],
312312
)
313313

314-
fk_constraints = [
315-
SQLWhereClauseIn(
316-
type_="in",
317-
columns=list(fk.columns),
318-
values=SQLStatementSelect(
319-
type_="select",
320-
columns=list(fk.dst_columns),
321-
from_=SQLTableIdentifier(
314+
def _is_distinct(table_obj: sa.Table, cols: Iterable[str]) -> bool:
    """Return True when ``cols`` are guaranteed unique per row of ``table_obj``.

    The columns are considered unique when some primary-key or unique
    constraint of the table consists solely of columns found in ``cols``.
    """
    cols_st = set(cols)
    for constraint in table_obj.constraints:
        if not isinstance(
            constraint, (sa.PrimaryKeyConstraint, sa.UniqueConstraint)
        ):
            continue
        constraint_cols = {col.name for col in constraint.columns}
        # SQLAlchemy attaches a PrimaryKeyConstraint to every Table, with no
        # columns when the table has no primary key. An empty set is a subset
        # of anything, so a column-less constraint must not count as proof of
        # uniqueness.
        if constraint_cols and constraint_cols <= cols_st:
            return True
    return False
324+
325+
fk_joins = []
326+
for fk in foreign_keys or rev_foreign_keys:
327+
dst_table = self.meta.tables[(fk.dst_schema, fk.dst_table)]
328+
half_unique = _is_distinct(table.table_obj, fk.columns) or _is_distinct(
329+
dst_table.table_obj, fk.dst_columns
330+
)
331+
fk_joins.append(
332+
SQLLeftJoin(
333+
right=SQLTableIdentifier(
322334
table_schema=fk.dst_schema,
323335
table_name=fk.dst_table,
324336
sampled=True,
325337
),
326-
),
327-
)
328-
for fk in foreign_keys or rev_foreign_keys
329-
]
330-
331-
fk_constraint: SQLWhereClause
332-
if foreign_keys:
333-
fk_constraint = SQLWhereClauseAnd(
334-
type_="and",
335-
conditions=fk_constraints,
336-
)
337-
else:
338-
fk_constraint = SQLWhereClauseOr(
339-
type_="or",
340-
conditions=fk_constraints,
338+
left_columns=list(fk.columns),
339+
right_columns=list(fk.dst_columns),
340+
half_unique=half_unique,
341+
)
341342
)
342343

343344
conf_constraints = self.config.table_constraints.get(
@@ -365,23 +366,25 @@ def _plan_table(
365366
)
366367
)
367368

369+
statements: List[SQLStatementSelect] = []
370+
368371
# Calculate initial foreign-key / config constraint statement
369-
statements: List[SQLStatementSelect] = [
370-
SQLStatementSelect(
371-
type_="select",
372-
from_=SQLTableIdentifier(
373-
table_schema=table.schema,
374-
table_name=table.name,
375-
),
376-
where=SQLWhereClauseAnd(
377-
type_="and",
378-
conditions=[
379-
*conf_constraints_sql,
380-
fk_constraint,
381-
],
382-
),
372+
if foreign_keys or rev_foreign_keys:
373+
statements.append(
374+
SQLStatementSelect(
375+
type_="select",
376+
from_=SQLTableIdentifier(
377+
table_schema=table.schema,
378+
table_name=table.name,
379+
),
380+
joins=fk_joins,
381+
joins_outer=not foreign_keys,
382+
where=SQLWhereClauseAnd(
383+
type_="and",
384+
conditions=conf_constraints_sql,
385+
),
386+
)
383387
)
384-
]
385388

386389
# If targeted also calculate target constraint statement
387390
if target:

subsetter/sampler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,12 @@ def create(
123123
raise
124124

125125
# Copy data into the temporary table
126-
result = conn.execute(
127-
table_obj.insert().from_select(list(table_obj.columns), select)
126+
stmt = table_obj.insert().from_select(list(table_obj.columns), select)
127+
LOGGER.debug(
128+
" Using statement %s",
129+
str(stmt.compile(dialect=conn.engine.dialect)).replace("\n", " "),
128130
)
129-
result = conn.execute(table_obj.select())
131+
result = conn.execute(stmt)
130132

131133
return table_obj, result.rowcount
132134

tests/data/big_join.yaml

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
dataset: big_join
2+
3+
plan_config:
4+
targets:
5+
test.homes:
6+
in:
7+
state: [MI, CA]
8+
extra_fks:
9+
- src_table: test.homes
10+
src_columns: [state]
11+
dst_table: test.users
12+
dst_columns: [state]
13+
select:
14+
- test.*
15+
16+
sample_config: {}
17+
18+
expected_plan:
19+
queries:
20+
test.homes:
21+
statement:
22+
from:
23+
schema: test
24+
table: homes
25+
type: select
26+
where:
27+
columns:
28+
- state
29+
type: in
30+
values:
31+
- - MI
32+
- - CA
33+
test.users:
34+
statement:
35+
from:
36+
schema: test
37+
table: users
38+
joins:
39+
- half_unique: false
40+
left_columns:
41+
- state
42+
right:
43+
sampled: true
44+
schema: test
45+
table: homes
46+
right_columns:
47+
- state
48+
joins_outer: true
49+
type: select
50+
51+
expected_sample:
52+
test_out.homes:
53+
- id: 1
54+
name: home1
55+
state: MI
56+
- id: 2
57+
name: home2
58+
state: MI
59+
- id: 4
60+
name: home4
61+
state: MI
62+
- id: 5
63+
name: home5
64+
state: MI
65+
- id: 7
66+
name: home7
67+
state: MI
68+
- id: 8
69+
name: home8
70+
state: MI
71+
- id: 11
72+
name: home11
73+
state: CA
74+
- id: 12
75+
name: home12
76+
state: CA
77+
test_out.users:
78+
- id: 1
79+
name: john
80+
state: MI
81+
- id: 3
82+
name: richard
83+
state: CA
84+
- id: 5
85+
name: ashley
86+
state: MI
87+
- id: 6
88+
name: corey
89+
state: MI
90+
- id: 7
91+
name: teresa
92+
state: MI
93+
- id: 8
94+
name: jake
95+
state: MI
96+
- id: 10
97+
name: holly
98+
state: CA

tests/data/datasets/big_join.yaml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
tables:
2+
test.users:
3+
primary_key: [id]
4+
columns: [id, name|str, state|str]
5+
test.homes:
6+
primary_key: [id]
7+
columns: [id, name|str, state|str]
8+
9+
data:
10+
test.users:
11+
- [1, john, MI]
12+
- [2, peter, WI]
13+
- [3, richard, CA]
14+
- [4, jeff, NY]
15+
- [5, ashley, MI]
16+
- [6, corey, MI]
17+
- [7, teresa, MI]
18+
- [8, jake, MI]
19+
- [9, erin, WI]
20+
- [10, holly, CA]
21+
test.homes:
22+
- [1, home1, MI]
23+
- [2, home2, MI]
24+
- [3, home3, WI]
25+
- [4, home4, MI]
26+
- [5, home5, MI]
27+
- [6, home6, WI]
28+
- [7, home7, MI]
29+
- [8, home8, MI]
30+
- [9, home9, NY]
31+
- [10, home10, OR]
32+
- [11, home11, CA]
33+
- [12, home12, CA]

0 commit comments

Comments
 (0)