
Commit f92af4d

fix: improve MSSQL bulk copy handling and fallback to superclass implementation
1 parent f5ec3ff commit f92af4d

3 files changed: +92 -76 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -194,6 +194,7 @@ module = [
     "databricks_cli.*",
     "mysql.*",
     "pymssql.*",
+    "pyodbc.*",
     "psycopg2.*",
     "langchain.*",
     "pytest_lazyfixture.*",

sqlmesh/core/config/connection.py

Lines changed: 82 additions & 67 deletions
@@ -1345,9 +1345,13 @@ def _validate_auth_configuration(cls, data: t.Any) -> t.Any:
         if driver == "pyodbc" and auth_type and auth_type != "default":
             if auth_type == "service_principal":
                 if not data.get("tenant_id") or not data.get("client_id"):
-                    raise ConfigError("Service principal authentication requires tenant_id and client_id")
+                    raise ConfigError(
+                        "Service principal authentication requires tenant_id and client_id"
+                    )
                 if not data.get("client_secret") and not data.get("certificate_path"):
-                    raise ConfigError("Service principal authentication requires either client_secret or certificate_path")
+                    raise ConfigError(
+                        "Service principal authentication requires either client_secret or certificate_path"
+                    )
             elif auth_type == "msi" and data.get("msi_client_id") and not data.get("client_id"):
                 # If msi_client_id is provided, copy it to client_id for consistency
                 data["client_id"] = data["msi_client_id"]
@@ -1370,18 +1374,20 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
             "autocommit",
             "tds_version",
         }
-
+
         if self.driver == "pyodbc":
-            base_keys.update({
-                "driver_name",
-                "trust_server_certificate",
-                "encrypt",
-                "odbc_properties",
-            })
+            base_keys.update(
+                {
+                    "driver_name",
+                    "trust_server_certificate",
+                    "encrypt",
+                    "odbc_properties",
+                }
+            )
             # Remove pymssql-specific parameters
             base_keys.discard("tds_version")
             base_keys.discard("conn_properties")
-
+
         return base_keys

     @property
@@ -1392,64 +1398,73 @@ def _engine_adapter(self) -> t.Type[EngineAdapter]:
     def _connection_factory(self) -> t.Callable:
         if self.driver == "pymssql":
             import pymssql
+
             return pymssql.connect
-        else: # pyodbc
-            import pyodbc
-
-            def connect(**kwargs: t.Any) -> t.Callable:
-                # Extract parameters for connection string
-                host = kwargs.pop("host")
-                port = kwargs.pop("port", 1433)
-                database = kwargs.pop("database", "")
-                user = kwargs.pop("user", None)
-                password = kwargs.pop("password", None)
-                driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server")
-                trust_server_certificate = kwargs.pop("trust_server_certificate", False)
-                encrypt = kwargs.pop("encrypt", True)
-                login_timeout = kwargs.pop("login_timeout", 60)
-
-                # Build connection string
-                conn_str_parts = [
-                    f"DRIVER={{{driver_name}}}",
-                    f"SERVER={host},{port}",
-                ]
-
-                if database:
-                    conn_str_parts.append(f"DATABASE={database}")
-
-                # Add security options
-                conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
-                if trust_server_certificate:
-                    conn_str_parts.append("TrustServerCertificate=YES")
-
-                conn_str_parts.append(f"Connection Timeout={login_timeout}")
-
-                # Standard SQL Server authentication
-                if user:
-                    conn_str_parts.append(f"UID={user}")
-                if password:
-                    conn_str_parts.append(f"PWD={password}")
-
-                # Add any additional ODBC properties from the odbc_properties dictionary
-                if self.odbc_properties:
-                    for key, value in self.odbc_properties.items():
-                        # Skip properties that we've already set above
-                        if key.lower() in ('driver', 'server', 'database', 'uid', 'pwd',
-                                           'encrypt', 'trustservercertificate', 'connection timeout'):
-                            continue
-
-                        # Handle boolean values properly
-                        if isinstance(value, bool):
-                            conn_str_parts.append(f"{key}={'YES' if value else 'NO'}")
-                        else:
-                            conn_str_parts.append(f"{key}={value}")
-
-                # Create the connection string
-                conn_str = ";".join(conn_str_parts)
-
-                return pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False))
-
-            return connect
+        # pyodbc
+        import pyodbc
+
+        def connect(**kwargs: t.Any) -> t.Callable:
+            # Extract parameters for connection string
+            host = kwargs.pop("host")
+            port = kwargs.pop("port", 1433)
+            database = kwargs.pop("database", "")
+            user = kwargs.pop("user", None)
+            password = kwargs.pop("password", None)
+            driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server")
+            trust_server_certificate = kwargs.pop("trust_server_certificate", False)
+            encrypt = kwargs.pop("encrypt", True)
+            login_timeout = kwargs.pop("login_timeout", 60)
+
+            # Build connection string
+            conn_str_parts = [
+                f"DRIVER={{{driver_name}}}",
+                f"SERVER={host},{port}",
+            ]
+
+            if database:
+                conn_str_parts.append(f"DATABASE={database}")
+
+            # Add security options
+            conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
+            if trust_server_certificate:
+                conn_str_parts.append("TrustServerCertificate=YES")
+
+            conn_str_parts.append(f"Connection Timeout={login_timeout}")
+
+            # Standard SQL Server authentication
+            if user:
+                conn_str_parts.append(f"UID={user}")
+            if password:
+                conn_str_parts.append(f"PWD={password}")
+
+            # Add any additional ODBC properties from the odbc_properties dictionary
+            if self.odbc_properties:
+                for key, value in self.odbc_properties.items():
+                    # Skip properties that we've already set above
+                    if key.lower() in (
+                        "driver",
+                        "server",
+                        "database",
+                        "uid",
+                        "pwd",
+                        "encrypt",
+                        "trustservercertificate",
+                        "connection timeout",
+                    ):
+                        continue
+
+                    # Handle boolean values properly
+                    if isinstance(value, bool):
+                        conn_str_parts.append(f"{key}={'YES' if value else 'NO'}")
+                    else:
+                        conn_str_parts.append(f"{key}={value}")
+
+            # Create the connection string
+            conn_str = ";".join(conn_str_parts)
+
+            return pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False))
+
+        return connect

     @property
     def _extra_engine_config(self) -> t.Dict[str, t.Any]:
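
As a rough sketch of the string the new connect() factory assembles, with an invented host, credentials, and odbc_properties (placeholders, not part of the commit):

# Mirrors the string assembly above with made-up inputs; illustration only.
driver_name = "ODBC Driver 18 for SQL Server"
host, port, database = "sqlserver.example.com", 1433, "analytics"
user, password = "sqlmesh_user", "placeholder-password"
odbc_properties = {"ApplicationIntent": "ReadOnly", "MultiSubnetFailover": True}

parts = [f"DRIVER={{{driver_name}}}", f"SERVER={host},{port}", f"DATABASE={database}"]
parts += ["Encrypt=YES", "TrustServerCertificate=YES", "Connection Timeout=60"]
parts += [f"UID={user}", f"PWD={password}"]
for key, value in odbc_properties.items():
    parts.append(f"{key}={'YES' if value else 'NO'}" if isinstance(value, bool) else f"{key}={value}")

print(";".join(parts))
# Printed as a single line (wrapped here for readability):
# DRIVER={ODBC Driver 18 for SQL Server};SERVER=sqlserver.example.com,1433;DATABASE=analytics;
# Encrypt=YES;TrustServerCertificate=YES;Connection Timeout=60;UID=sqlmesh_user;PWD=placeholder-password;
# ApplicationIntent=ReadOnly;MultiSubnetFailover=YES

The factory then hands this string to pyodbc.connect together with the autocommit flag left in kwargs.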

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 9 additions & 9 deletions
@@ -212,6 +212,10 @@ def _df_to_source_queries(
         assert isinstance(df, pd.DataFrame)
         temp_table = self._get_temp_table(target_table or "pandas")

+        # Return the superclass implementation if the connection pool doesn't support bulk_copy
+        if not hasattr(self._connection_pool.get(), "bulk_copy"):
+            return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
+
         def query_factory() -> Query:
             # It is possible for the factory to be called multiple times and if so then the temp table will already
             # be created so we skip creating again. This means we are assuming the first call is the same result
@@ -222,15 +226,11 @@ def query_factory() -> Query:
             self.create_table(temp_table, columns_to_types_create)
             conn = self._connection_pool.get()

-            if hasattr(conn, 'bulk_copy'):
-                # Use bulk_copy if available
-                rows: t.List[t.Tuple[t.Any, ...]] = list(
-                    df.replace({np.nan: None}).itertuples(index=False, name=None)  # type: ignore
-                )
-                conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows)
-            else:
-                # Fallback to the superclass implementation of _df_to_source_queries if bulk_copy is not available
-                return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
+            rows: t.List[t.Tuple[t.Any, ...]] = list(
+                df.replace({np.nan: None}).itertuples(index=False, name=None)  # type: ignore
+            )
+            conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows)
+
             return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)  # type: ignore

         return [
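
A small sketch of the row conversion that feeds bulk_copy above; the DataFrame contents are invented, and the point is that NaN becomes None and each row becomes a plain tuple (presumably pymssql connections expose bulk_copy, while the pyodbc path now falls back to the superclass implementation early):

import numpy as np
import pandas as pd

# Made-up frame; mirrors df.replace({np.nan: None}).itertuples(index=False, name=None) above.
df = pd.DataFrame({"id": [1, 2], "name": ["a", None], "score": [1.5, np.nan]})
rows = list(df.replace({np.nan: None}).itertuples(index=False, name=None))
# rows is roughly [(1, 'a', 1.5), (2, None, None)] -- the shape conn.bulk_copy(table_name, rows) expects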
