Skip to content

Commit 88dfd99

Browse files
authored
Use group by instead of distinct (#40)
1 parent 1a3dd26 commit 88dfd99

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

subsetter/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _main_plan(args):
158158
ctx = open(args.plan_output, "w", encoding="utf-8")
159159
with ctx as fplan:
160160
yaml.dump(
161-
plan.dict(exclude_unset=True, by_alias=True),
161+
plan.model_dump(exclude_unset=True, by_alias=True),
162162
stream=fplan,
163163
default_flow_style=False,
164164
width=2**20,

subsetter/plan_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def build(self, context: SQLBuildContext):
266266
for join in self.joins: # pylint: disable=not-an-iterable
267267
right = join.right.build(context).alias()
268268

269-
if join.half_unique:
269+
if join.half_unique and table_obj.primary_key:
270270
joined = joined.join(
271271
right,
272272
onclause=sa.and_(
@@ -294,7 +294,10 @@ def build(self, context: SQLBuildContext):
294294
)
295295
)
296296

297-
stmt = stmt.select_from(joined).distinct()
297+
stmt = stmt.select_from(joined)
298+
if joined is not table_obj:
299+
stmt = stmt.group_by(*table_obj.primary_key.columns)
300+
298301
if self.joins_outer:
299302
exists_constraints.extend(col.is_not(None) for col in joined_cols)
300303
stmt = stmt.where(sa.or_(*exists_constraints))

subsetter/sampler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def create(
7070
select: sa.Select,
7171
*,
7272
name: str = "",
73+
primary_key: Tuple[str, ...] = (),
7374
indexes: Iterable[Tuple[str, ...]] = (),
7475
) -> Tuple[sa.Table, int]:
7576
"""
@@ -83,6 +84,9 @@ def create(
8384
schema: The schema to create the temporary table within. For some dialects
8485
temporary tables always exist in their own schema and this parameter
8586
will be ignored.
87+
primary_key: If set will mark the set of columns passed as primary keys in
88+
the temporary table. This tuple should match a subset of the
89+
column names in the select query.
8690
indexes: creates an index on each tuple of columns listed. This is useful
8791
if future queries are likely to reference these columns.
8892
@@ -106,7 +110,10 @@ def create(
106110
metadata,
107111
schema=temp_schema,
108112
prefixes=["TEMPORARY"],
109-
*(sa.Column(col.name, col.type) for col in select.selected_columns),
113+
*(
114+
sa.Column(col.name, col.type, primary_key=col.name in primary_key)
115+
for col in select.selected_columns
116+
),
110117
)
111118
try:
112119
metadata.create_all(conn)
@@ -120,6 +127,8 @@ def create(
120127
raise
121128

122129
for idx, index_cols in enumerate(indexes):
130+
if index_cols == primary_key:
131+
continue
123132
# For some dialects/data types we may not be able to construct an index. We just do our
124133
# best here instead of hard failing.
125134
try:
@@ -891,6 +900,7 @@ def _materialize_tables(
891900
schema,
892901
table_q,
893902
name=table_name,
903+
primary_key=table.primary_key,
894904
indexes=joined_columns[(schema, table_name)],
895905
)
896906
)
@@ -914,6 +924,7 @@ def _materialize_tables(
914924
schema,
915925
meta.temp_tables[(schema, table_name, 0)].select(),
916926
name=table_name,
927+
primary_key=table.primary_key,
917928
indexes=joined_columns[(schema, table_name)],
918929
)
919930
)

0 commit comments

Comments
 (0)