diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index c531c102..97a966b5 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -22,11 +22,9 @@ import datetime from decimal import Decimal import random -import operator import uuid from google import auth -import google.api_core.exceptions from google.cloud.bigquery import dbapi from google.cloud.bigquery.table import ( RangePartitioning, @@ -1054,11 +1052,6 @@ def dbapi(cls): def import_dbapi(cls): return dbapi - @staticmethod - def _build_formatted_table_id(table): - """Build '.' string using given table.""" - return "{}.{}".format(table.reference.dataset_id, table.table_id) - @staticmethod def _add_default_dataset_to_job_config(job_config, project_id, dataset_id): # If dataset_id is set, then we know the job_config isn't None @@ -1107,36 +1100,34 @@ def create_connect_args(self, url): ) return ([], {"client": client}) - def _get_table_or_view_names(self, connection, item_types, schema=None): - current_schema = schema or self.dataset_id - get_table_name = ( - self._build_formatted_table_id - if self.dataset_id is None - else operator.attrgetter("table_id") - ) + def _get_default_schema_name(self, connection) -> str: + return connection.dialect.dataset_id + def _get_table_or_view_names(self, connection, item_types, schema=None): client = connection.connection._client - datasets = client.list_datasets() - - result = [] - for dataset in datasets: - if current_schema is not None and current_schema != dataset.dataset_id: - continue - - try: - tables = client.list_tables( - dataset.reference, page_size=self.list_tables_page_size + # `schema=None` means to search the default schema. If one isn't set in the + # connection string, then we have nothing to search so return an empty list. + # + # When using Alembic with `include_schemas=False`, it expects to compare to a + # single schema. If `include_schemas=True`, it will enumerate all schemas and + # then call `get_table_names`/`get_view_names` for each schema. + current_schema = schema or self.default_schema_name + if current_schema is None: + return [] + try: + return [ + table.table_id + for table in client.list_tables( + current_schema, page_size=self.list_tables_page_size ) - for table in tables: - if table.table_type in item_types: - result.append(get_table_name(table)) - except google.api_core.exceptions.NotFound: - # It's possible that the dataset was deleted between when we - # fetched the list of datasets and when we try to list the - # tables from it. See: - # https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105 - pass - return result + if table.table_type in item_types + ] + except NotFound: + # It's possible that the dataset was deleted between when we + # fetched the list of datasets and when we try to list the + # tables from it. See: + # https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105 + return [] @staticmethod def _split_table_name(full_table_name): diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 7ea4ccc6..29fd4e8c 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -366,18 +366,6 @@ def test_reflect_dataset_does_not_exist(engine): ) -def test_tables_list(engine, engine_using_test_dataset, bigquery_dataset): - tables = sqlalchemy.inspect(engine).get_table_names() - assert f"{bigquery_dataset}.sample" in tables - assert f"{bigquery_dataset}.sample_one_row" in tables - assert f"{bigquery_dataset}.sample_view" not in tables - - tables = sqlalchemy.inspect(engine_using_test_dataset).get_table_names() - assert "sample" in tables - assert "sample_one_row" in tables - assert "sample_view" not in tables - - def test_group_by(session, table, session_using_test_dataset, table_using_test_dataset): """labels in SELECT clause should be correclty formatted (dots are replaced with underscores)""" for session, table in [ @@ -612,14 +600,15 @@ def test_schemas_names(inspector, inspector_using_test_dataset, bigquery_dataset assert f"{bigquery_dataset}" in datasets -def test_table_names_in_schema( - inspector, inspector_using_test_dataset, bigquery_dataset -): +def test_table_names(inspector, inspector_using_test_dataset, bigquery_dataset): + tables = inspector.get_table_names() + assert not tables + tables = inspector.get_table_names(bigquery_dataset) - assert f"{bigquery_dataset}.sample" in tables - assert f"{bigquery_dataset}.sample_one_row" in tables - assert f"{bigquery_dataset}.sample_dml_empty" in tables - assert f"{bigquery_dataset}.sample_view" not in tables + assert "sample" in tables + assert "sample_one_row" in tables + assert "sample_dml_empty" in tables + assert "sample_view" not in tables assert len(tables) == 3 tables = inspector_using_test_dataset.get_table_names() @@ -632,8 +621,11 @@ def test_table_names_in_schema( def test_view_names(inspector, inspector_using_test_dataset, bigquery_dataset): view_names = inspector.get_view_names() - assert f"{bigquery_dataset}.sample_view" in view_names - assert f"{bigquery_dataset}.sample" not in view_names + assert not view_names + + view_names = inspector.get_view_names(bigquery_dataset) + assert "sample_view" in view_names + assert "sample" not in view_names view_names = inspector_using_test_dataset.get_view_names() assert "sample_view" in view_names diff --git a/tests/unit/fauxdbi.py b/tests/unit/fauxdbi.py index 4d8f02b6..afcdb2c6 100644 --- a/tests/unit/fauxdbi.py +++ b/tests/unit/fauxdbi.py @@ -482,8 +482,8 @@ def list_tables(self, dataset, page_size): google.cloud.bigquery.table.TableListItem( dict( tableReference=dict( - projectId=dataset.project, - datasetId=dataset.dataset_id, + projectId="myproject", + datasetId=dataset, tableId=row["name"], ), type=row["type"].upper(), diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index db20e2f0..0086e08a 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -65,83 +65,70 @@ def table_item(dataset_id, table_id, type_="TABLE"): @pytest.mark.parametrize( - ["datasets_list", "tables_lists", "expected"], + ["dataset", "tables_list", "expected"], [ - ([], [], []), - ([dataset_item("dataset_1")], [[]], []), + (None, [], []), + ("dataset", [], []), ( - [dataset_item("dataset_1"), dataset_item("dataset_2")], + "dataset", [ - [table_item("dataset_1", "d1t1"), table_item("dataset_1", "d1t2")], - [ - table_item("dataset_2", "d2t1"), - table_item("dataset_2", "d2view", type_="VIEW"), - table_item("dataset_2", "d2ext", type_="EXTERNAL"), - table_item("dataset_2", "d2mv", type_="MATERIALIZED_VIEW"), - ], + table_item("dataset", "t1"), + table_item("dataset", "view", type_="VIEW"), + table_item("dataset", "ext", type_="EXTERNAL"), + table_item("dataset", "mv", type_="MATERIALIZED_VIEW"), ], - ["dataset_1.d1t1", "dataset_1.d1t2", "dataset_2.d2t1", "dataset_2.d2ext"], + ["t1", "ext"], ), ( - [dataset_item("dataset_1"), dataset_item("dataset_deleted")], - [ - [table_item("dataset_1", "d1t1")], - google.api_core.exceptions.NotFound("dataset_deleted"), - ], - ["dataset_1.d1t1"], + "dataset", + google.api_core.exceptions.NotFound("dataset_deleted"), + [], ), ], ) def test_get_table_names( - engine_under_test, mock_bigquery_client, datasets_list, tables_lists, expected + engine_under_test, mock_bigquery_client, dataset, tables_list, expected ): - mock_bigquery_client.list_datasets.return_value = datasets_list - mock_bigquery_client.list_tables.side_effect = tables_lists - table_names = sqlalchemy.inspect(engine_under_test).get_table_names() - mock_bigquery_client.list_datasets.assert_called_once() - assert mock_bigquery_client.list_tables.call_count == len(datasets_list) + mock_bigquery_client.list_tables.side_effect = [tables_list] + table_names = sqlalchemy.inspect(engine_under_test).get_table_names(schema=dataset) + if dataset: + mock_bigquery_client.list_tables.assert_called_once() + else: + mock_bigquery_client.list_tables.assert_not_called() assert list(sorted(table_names)) == list(sorted(expected)) @pytest.mark.parametrize( - ["datasets_list", "tables_lists", "expected"], + ["dataset", "tables_list", "expected"], [ - ([], [], []), - ([dataset_item("dataset_1")], [[]], []), + (None, [], []), + ("dataset", [], []), ( - [dataset_item("dataset_1"), dataset_item("dataset_2")], + "dataset", [ - [ - table_item("dataset_1", "d1t1"), - table_item("dataset_1", "d1view", type_="VIEW"), - ], - [ - table_item("dataset_2", "d2t1"), - table_item("dataset_2", "d2view", type_="VIEW"), - table_item("dataset_2", "d2ext", type_="EXTERNAL"), - table_item("dataset_2", "d2mv", type_="MATERIALIZED_VIEW"), - ], + table_item("dataset", "t1"), + table_item("dataset", "view", type_="VIEW"), + table_item("dataset", "ext", type_="EXTERNAL"), + table_item("dataset", "mv", type_="MATERIALIZED_VIEW"), ], - ["dataset_1.d1view", "dataset_2.d2view", "dataset_2.d2mv"], + ["view", "mv"], ), ( - [dataset_item("dataset_1"), dataset_item("dataset_deleted")], - [ - [table_item("dataset_1", "d1view", type_="VIEW")], - google.api_core.exceptions.NotFound("dataset_deleted"), - ], - ["dataset_1.d1view"], + "dataset_deleted", + google.api_core.exceptions.NotFound("dataset_deleted"), + [], ), ], ) def test_get_view_names( - inspector_under_test, mock_bigquery_client, datasets_list, tables_lists, expected + inspector_under_test, mock_bigquery_client, dataset, tables_list, expected ): - mock_bigquery_client.list_datasets.return_value = datasets_list - mock_bigquery_client.list_tables.side_effect = tables_lists - view_names = inspector_under_test.get_view_names() - mock_bigquery_client.list_datasets.assert_called_once() - assert mock_bigquery_client.list_tables.call_count == len(datasets_list) + mock_bigquery_client.list_tables.side_effect = [tables_list] + view_names = inspector_under_test.get_view_names(schema=dataset) + if dataset: + mock_bigquery_client.list_tables.assert_called_once() + else: + mock_bigquery_client.list_tables.assert_not_called() assert list(sorted(view_names)) == list(sorted(expected))