15
15
# specific language governing permissions and limitations
16
16
# under the License.
17
17
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
18
+ from __future__ import annotations
19
+
18
20
import copy
19
21
from collections import namedtuple
20
22
from datetime import datetime
@@ -719,7 +721,15 @@ def test_adjust_engine_params_catalog_only() -> None:
719
721
assert str (uri ) == "trino://user:pass@localhost:8080/new_catalog/new_schema"
720
722
721
723
722
- def test_get_default_catalog () -> None :
724
+ @pytest .mark .parametrize (
725
+ "sqlalchemy_uri,result" ,
726
+ [
727
+ ("trino://user:pass@localhost:8080/system" , "system" ),
728
+ ("trino://user:pass@localhost:8080/system/default" , "system" ),
729
+ ("trino://trino@localhost:8081" , None ),
730
+ ],
731
+ )
732
+ def test_get_default_catalog (sqlalchemy_uri : str , result : str | None ) -> None :
723
733
"""
724
734
Test the ``get_default_catalog`` method.
725
735
"""
@@ -728,15 +738,9 @@ def test_get_default_catalog() -> None:
728
738
729
739
database = Database (
730
740
database_name = "my_db" ,
731
- sqlalchemy_uri = "trino://user:pass@localhost:8080/system" ,
732
- )
733
- assert TrinoEngineSpec .get_default_catalog (database ) == "system"
734
-
735
- database = Database (
736
- database_name = "my_db" ,
737
- sqlalchemy_uri = "trino://user:pass@localhost:8080/system/default" ,
741
+ sqlalchemy_uri = sqlalchemy_uri ,
738
742
)
739
- assert TrinoEngineSpec .get_default_catalog (database ) == "system"
743
+ assert TrinoEngineSpec .get_default_catalog (database ) == result
740
744
741
745
742
746
@patch ("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition" )
0 commit comments