Skip to content

Commit e62007b

Browse files
authored
Improve config error checking (#34)
1 parent 11d6dd6 commit e62007b

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

subsetter/config_model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Literal, Optional, Union
22

3-
from pydantic import BaseModel, ConfigDict, Field
3+
from pydantic import BaseModel, ConfigDict, Field, model_validator
44
from typing_extensions import Annotated
55

66
from subsetter.common import DatabaseConfig, SQLKnownOperator, SQLLiteralType
@@ -30,6 +30,19 @@ class ExtraFKConfig(ForbidBaseModel):
3030
dst_table: str
3131
dst_columns: List[str]
3232

33+
@model_validator(mode="after")
34+
def check_columns_match(self):
35+
col_count = len(self.src_columns)
36+
if not col_count:
37+
raise ValueError("src_columns cannot be empty")
38+
if len(self.dst_columns) != col_count:
39+
raise ValueError("src_columns and dst_columns must be the same length")
40+
if len(set(self.src_columns)) != col_count:
41+
raise ValueError("each column in src_columns must be unique")
42+
if len(set(self.dst_columns)) != col_count:
43+
raise ValueError("each column in src_columns must be unique")
44+
return self
45+
3346
class ColumnConstraint(ForbidBaseModel):
3447
column: str
3548
operator: SQLKnownOperator

subsetter/planner.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,56 @@ def _solve_order(self) -> List[str]:
138138

139139
def _add_extra_fks(self) -> None:
140140
"""Add in additional foreign keys requested."""
141-
for extra_fk in self.config.extra_fks:
141+
for index, extra_fk in enumerate(self.config.extra_fks):
142142
src_schema, src_table_name = parse_table_name(extra_fk.src_table)
143143
dst_schema, dst_table_name = parse_table_name(extra_fk.dst_table)
144144
table = self.meta.tables.get((src_schema, src_table_name))
145145
if table is None:
146146
LOGGER.warning(
147-
"Found no source table %s.%s referenced in add_extra_fks",
147+
"Found no source table %s.%s referenced in extra_fks[%d]",
148148
src_schema,
149149
src_table_name,
150+
index,
150151
)
151152
continue
152-
if (dst_schema, dst_table_name) not in self.meta.tables:
153+
154+
src_missing_cols = {
155+
col
156+
for col in extra_fk.src_columns
157+
if col not in table.table_obj.columns
158+
}
159+
if src_missing_cols:
160+
LOGGER.warning(
161+
"Columns %s do not exist in %s.%s referenced in extra_fks[%d]",
162+
src_missing_cols,
163+
src_schema,
164+
src_table_name,
165+
index,
166+
)
167+
continue
168+
169+
dst_table = self.meta.tables.get((dst_schema, dst_table_name))
170+
if dst_table is None:
171+
LOGGER.warning(
172+
"Found no destination table %s.%s referenced in add_extra_fks[%d]",
173+
dst_schema,
174+
dst_table_name,
175+
index,
176+
)
177+
continue
178+
179+
dst_missing_cols = {
180+
col
181+
for col in extra_fk.dst_columns
182+
if col not in dst_table.table_obj.columns
183+
}
184+
if dst_missing_cols:
153185
LOGGER.warning(
154-
"Found no destination table %s.%s referenced in add_extra_fks",
186+
"Columns %s do not exist in %s.%s referenced in extra_fks[%d]",
187+
dst_missing_cols,
155188
dst_schema,
156189
dst_table_name,
190+
index,
157191
)
158192
continue
159193

0 commit comments

Comments
 (0)