diff --git a/docs/integrations/engines/azuresql.md b/docs/integrations/engines/azuresql.md index e9b97abaa1..5b54ffa9c6 100644 --- a/docs/integrations/engines/azuresql.md +++ b/docs/integrations/engines/azuresql.md @@ -2,15 +2,18 @@ [Azure SQL](https://azure.microsoft.com/en-us/products/azure-sql) is "a family of managed, secure, and intelligent products that use the SQL Server database engine in the Azure cloud." -The Azure SQL adapter only supports authentication with a username and password. It does not support authentication with Microsoft Entra or Azure Active Directory. - ## Local/Built-in Scheduler **Engine Adapter Type**: `azuresql` ### Installation +#### User / Password Authentication: ``` pip install "sqlmesh[azuresql]" ``` +#### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[azuresql-odbc]" +``` ### Connection options @@ -18,8 +21,8 @@ pip install "sqlmesh[azuresql]" | ----------------- | ---------------------------------------------------------------- | :----------: | :------: | | `type` | Engine type name - must be `azuresql` | string | Y | | `host` | The hostname of the Azure SQL server | string | Y | -| `user` | The username to use for authentication with the Azure SQL server | string | N | -| `password` | The password to use for authentication with the Azure SQL server | string | N | +| `user` | The username / client ID to use for authentication with the Azure SQL server | string | N | +| `password` | The password / client secret to use for authentication with the Azure SQL server | string | N | | `port` | The port number of the Azure SQL server | int | N | | `database` | The target database | string | N | | `charset` | The character set used for the connection | string | N | @@ -27,4 +30,7 @@ pip install "sqlmesh[azuresql]" | `login_timeout` | The timeout for connection and login in seconds. 
Default: 60 | int | N | | `appname` | The application name to use for the connection | string | N | | `conn_properties` | The list of connection properties | list[string] | N | -| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | \ No newline at end of file +| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The driver to use for the connection: `pymssql` or `pyodbc`. Default: `pymssql` | string | N | +| `driver_name` | The ODBC driver name to use when `driver` is `pyodbc`. E.g., *ODBC Driver 18 for SQL Server* | string | N | +| `odbc_properties` | The dict of additional ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See the [ODBC connection string attributes reference](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16) for the available options. | dict | N | \ No newline at end of file diff --git a/docs/integrations/engines/mssql.md b/docs/integrations/engines/mssql.md index 1650319d07..f06b5f1387 100644 --- a/docs/integrations/engines/mssql.md +++ b/docs/integrations/engines/mssql.md @@ -4,9 +4,14 @@ **Engine Adapter Type**: `mssql` ### Installation +#### User / Password Authentication: ``` pip install "sqlmesh[mssql]" ``` +#### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[mssql-odbc]" +``` ### Connection options @@ -14,8 +19,8 @@ pip install "sqlmesh[mssql]" | ----------------- | ------------------------------------------------------------ | :----------: | :------: | | `type` | Engine type name - must be `mssql` | string | Y | | `host` | The hostname of the MSSQL server | string | Y | -| `user` | The username to use for authentication with the MSSQL server | string | N | -| `password` | The password to use for authentication with the MSSQL server | string | N | +| `user` | The username / client ID to use for authentication with the MSSQL server | string | N | +| `password` | The password / client secret to use for authentication with the MSSQL server | string | N | | `port` | The port
number of the MSSQL server | int | N | | `database` | The target database | string | N | | `charset` | The character set used for the connection | string | N | @@ -24,3 +29,6 @@ pip install "sqlmesh[mssql]" | `appname` | The application name to use for the connection | string | N | | `conn_properties` | The list of connection properties | list[string] | N | | `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The driver to use for the connection: `pymssql` or `pyodbc`. Default: `pymssql` | string | N | +| `driver_name` | The ODBC driver name to use when `driver` is `pyodbc`. E.g., *ODBC Driver 18 for SQL Server* | string | N | +| `odbc_properties` | The dict of additional ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See the [ODBC connection string attributes reference](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16) for the available options. | dict | N | \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cfbe1cb293..80b7f1df6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ classifiers = [ [project.optional-dependencies] athena = ["PyAthena[Pandas]"] azuresql = ["pymssql"] +azuresql-odbc = ["pyodbc"] bigquery = [ "google-cloud-bigquery[pandas]", "google-cloud-bigquery-storage" @@ -99,6 +100,7 @@ gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"] github = ["PyGithub~=2.5.0"] llm = ["langchain", "openai"] mssql = ["pymssql"] +mssql-odbc = ["pyodbc"] mysql = ["pymysql"] mwaa = ["boto3"] postgres = ["psycopg2"] @@ -192,6 +194,7 @@ module = [ "databricks_cli.*", "mysql.*", "pymssql.*", + "pyodbc.*", "psycopg2.*", "langchain.*", "pytest_lazyfixture.*", diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index be0eee114a..ad3cc60729 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1278,6 +1278,33 @@ def _connection_factory(self) -> t.Callable: class MSSQLConnectionConfig(ConnectionConfig): + """Configuration for the MSSQL
connection. + + Args: + host: The hostname of the MSSQL server (required). + user: The username for authentication (optional when using alternative authentication methods). + password: The password for authentication (optional when using alternative authentication methods). + database: The target database (optional). + port: The server port (default: 1433). + timeout: Query timeout in seconds (default: 0, meaning no timeout). + login_timeout: Connection and login timeout (default: 60 seconds). + charset: Character set (default: "UTF-8"). + appname: Application name. + conn_properties: Connection properties. + autocommit: Autocommit mode (default: False). + tds_version: TDS protocol version. + driver: Connection driver to use, either "pymssql" or "pyodbc" (default: "pymssql"). + driver_name: ODBC driver name when using pyodbc (e.g., "ODBC Driver 18 for SQL Server"). + trust_server_certificate: Whether to trust the server certificate (for pyodbc). + encrypt: Whether to encrypt the connection (for pyodbc). + odbc_properties: Dictionary of arbitrary ODBC connection properties that will be passed directly to the connection string. + See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute for available options. + This can be used for authentication methods like Microsoft Entra ID (formerly Azure AD). + concurrent_tasks: The maximum number of tasks that can use this connection concurrently. + register_comments: Whether or not to register model comments with the SQL engine. + pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 
+ """ + host: str user: t.Optional[str] = None password: t.Optional[str] = None @@ -1290,6 +1317,15 @@ class MSSQLConnectionConfig(ConnectionConfig): conn_properties: t.Optional[t.Union[t.List[str], str]] = None autocommit: t.Optional[bool] = False tds_version: t.Optional[str] = None + # Driver options + driver: t.Literal["pymssql", "pyodbc"] = "pymssql" + # PyODBC specific options + driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server" + trust_server_certificate: t.Optional[bool] = None + encrypt: t.Optional[bool] = None + # Dictionary of arbitrary ODBC connection properties + # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute + odbc_properties: t.Optional[t.Dict[str, t.Any]] = None concurrent_tasks: int = 4 register_comments: bool = True @@ -1297,9 +1333,34 @@ class MSSQLConnectionConfig(ConnectionConfig): type_: t.Literal["mssql"] = Field(alias="type", default="mssql") + @model_validator(mode="before") + def _validate_auth_configuration(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + driver = data.get("driver", "pymssql") + auth_type = data.get("auth_type") + + # Validate requirements for Entra ID authentication methods when using PyODBC + if driver == "pyodbc" and auth_type and auth_type != "default": + if auth_type == "service_principal": + if not data.get("tenant_id") or not data.get("client_id"): + raise ConfigError( + "Service principal authentication requires tenant_id and client_id" + ) + if not data.get("client_secret") and not data.get("certificate_path"): + raise ConfigError( + "Service principal authentication requires either client_secret or certificate_path" + ) + elif auth_type == "msi" and data.get("msi_client_id") and not data.get("client_id"): + # If msi_client_id is provided, copy it to client_id for consistency + data["client_id"] = data["msi_client_id"] + + return data + @property def _connection_kwargs_keys(self) -> t.Set[str]: - return { + base_keys = 
{ "host", "user", "password", @@ -1314,15 +1375,96 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "tds_version", } + if self.driver == "pyodbc": + base_keys.update( + { + "driver_name", + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + ) + # Remove pymssql-specific parameters + base_keys.discard("tds_version") + base_keys.discard("conn_properties") + + return base_keys + @property def _engine_adapter(self) -> t.Type[EngineAdapter]: return engine_adapter.MSSQLEngineAdapter @property def _connection_factory(self) -> t.Callable: - import pymssql + if self.driver == "pymssql": + import pymssql + + return pymssql.connect + # pyodbc + import pyodbc + + def connect(**kwargs: t.Any) -> t.Callable: + # Extract parameters for connection string + host = kwargs.pop("host") + port = kwargs.pop("port", 1433) + database = kwargs.pop("database", "") + user = kwargs.pop("user", None) + password = kwargs.pop("password", None) + driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") + trust_server_certificate = kwargs.pop("trust_server_certificate", False) + encrypt = kwargs.pop("encrypt", True) + login_timeout = kwargs.pop("login_timeout", 60) + + # Build connection string + conn_str_parts = [ + f"DRIVER={{{driver_name}}}", + f"SERVER={host},{port}", + ] + + if database: + conn_str_parts.append(f"DATABASE={database}") + + # Add security options + conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") + if trust_server_certificate: + conn_str_parts.append("TrustServerCertificate=YES") + + conn_str_parts.append(f"Connection Timeout={login_timeout}") + + # Standard SQL Server authentication + if user: + conn_str_parts.append(f"UID={user}") + if password: + conn_str_parts.append(f"PWD={password}") + + # Add any additional ODBC properties from the odbc_properties dictionary + if self.odbc_properties: + for key, value in self.odbc_properties.items(): + # Skip properties that we've already set above + if key.lower() in ( + "driver", + 
"server", + "database", + "uid", + "pwd", + "encrypt", + "trustservercertificate", + "connection timeout", + ): + continue + + # Handle boolean values properly + if isinstance(value, bool): + conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") + else: + conn_str_parts.append(f"{key}={value}") - return pymssql.connect + # Create the connection string + conn_str = ";".join(conn_str_parts) + + return pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) + + return connect @property def _extra_engine_config(self) -> t.Dict[str, t.Any]: diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index f80f1816a3..701d9890b7 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -212,6 +212,10 @@ def _df_to_source_queries( assert isinstance(df, pd.DataFrame) temp_table = self._get_temp_table(target_table or "pandas") + # Return the superclass implementation if the connection pool doesn't support bulk_copy + if not hasattr(self._connection_pool.get(), "bulk_copy"): + return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + def query_factory() -> Query: # It is possible for the factory to be called multiple times and if so then the temp table will already # be created so we skip creating again. This means we are assuming the first call is the same result