diff --git a/dlt/destinations/impl/athena/athena_adapter.py b/dlt/destinations/impl/athena/athena_adapter.py index cb600335c0..50f7abc54a 100644 --- a/dlt/destinations/impl/athena/athena_adapter.py +++ b/dlt/destinations/impl/athena/athena_adapter.py @@ -4,7 +4,7 @@ from dlt.common.pendulum import timezone from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TColumnSchema -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate @@ -89,7 +89,7 @@ def athena_adapter( >>> athena_adapter(data, partition=["department", athena_partition.year("date_hired"), athena_partition.bucket(8, "name")]) [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if partition: diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 4dee572f57..55fe1b6b74 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -7,7 +7,7 @@ TColumnNames, TTableSchemaColumns, ) -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate @@ -78,7 +78,7 @@ def bigquery_adapter( >>> bigquery_adapter(data, partition="date_hired", table_expiration_datetime="2024-01-30", table_description="Employee Data") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} column_hints: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/clickhouse/clickhouse_adapter.py b/dlt/destinations/impl/clickhouse/clickhouse_adapter.py index dc030ef88c..41be531b71 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse_adapter.py +++ b/dlt/destinations/impl/clickhouse/clickhouse_adapter.py @@ -5,7 +5,7 @@ TABLE_ENGINE_TYPES, TABLE_ENGINE_TYPE_HINT, ) -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate @@ -46,7 +46,7 @@ def clickhouse_adapter(data: Any, table_engine_type: TTableEngineType = None) -> >>> clickhouse_adapter(data, table_engine_type="merge_tree") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if table_engine_type is not None: diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index bb33632b48..99d5ef43c6 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -1,7 +1,7 @@ from typing import Any from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource @@ -32,7 +32,7 @@ def lancedb_adapter( >>> lancedb_adapter(data, embed="description") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) column_hints: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index 215d87a920..e39d3e3644 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -2,7 +2,7 @@ from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract import DltResource, resource as make_resource -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter VECTORIZE_HINT = "x-qdrant-embed" @@ -32,7 +32,7 @@ def qdrant_adapter( >>> qdrant_adapter(data, embed="description") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) column_hints: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/synapse/synapse_adapter.py b/dlt/destinations/impl/synapse/synapse_adapter.py index 8b262f3621..e12823c7bf 100644 --- a/dlt/destinations/impl/synapse/synapse_adapter.py +++ b/dlt/destinations/impl/synapse/synapse_adapter.py @@ -3,7 +3,7 @@ from dlt.extract import DltResource, resource as make_resource from dlt.extract.items import TTableHintTemplate from dlt.extract.hints import TResourceHints -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter TTableIndexType = Literal["heap", "clustered_columnstore_index"] """ @@ -37,7 +37,7 @@ def synapse_adapter(data: Any, table_index_type: TTableIndexType = None) -> DltR >>> synapse_adapter(data, table_index_type="clustered_columnstore_index") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if table_index_type is not None: diff --git a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index a290ac65b4..9bd0b41783 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -2,7 +2,7 @@ from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract import DltResource, resource as make_resource -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter TTokenizationTMethod = Literal["word", "lowercase", "whitespace", "field"] TOKENIZATION_METHODS: Set[TTokenizationTMethod] = set(get_args(TTokenizationTMethod)) @@ -54,7 +54,7 @@ def weaviate_adapter( >>> weaviate_adapter(data, vectorize="description", tokenization={"description": "word"}) [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) column_hints: TTableSchemaColumns = {} if vectorize: diff --git a/dlt/destinations/utils.py b/dlt/destinations/utils.py index 9dd8b83509..fcc2c4fd16 100644 --- a/dlt/destinations/utils.py +++ b/dlt/destinations/utils.py @@ -1,4 +1,6 @@ import re +import inspect + from typing import Any, List, Optional, Tuple from dlt.common import logger @@ -14,16 +16,35 @@ from typing import Any, cast, Tuple, Dict, Type from dlt.destinations.exceptions import DatabaseTransientException -from dlt.extract import DltResource, resource as make_resource +from dlt.extract import DltResource, resource as make_resource, DltSource RE_DATA_TYPE = re.compile(r"([A-Z]+)\((\d+)(?:,\s?(\d+))?\)") -def ensure_resource(data: Any) -> DltResource: - """Wraps `data` in a DltResource if it's not a DltResource already.""" +def get_resource_for_adapter(data: Any) -> DltResource: + """ + Helper function for adapters. Wraps `data` in a DltResource if it's not a DltResource already. + Alternatively if `data` is a DltSource, throws an error if there are multiple resource in the source + or returns the single resource if available. + """ if isinstance(data, DltResource): return data - resource_name = None if hasattr(data, "__name__") else "content" + # prevent accidentally wrapping sources with adapters + if isinstance(data, DltSource): + if len(data.selected_resources.keys()) == 1: + return list(data.selected_resources.values())[0] + else: + raise ValueError( + "You are trying to use an adapter on a DltSource with multiple resources. You can" + " only use adapters on pure data, direclty on a DltResouce or a DltSource" + " containing a single DltResource." + ) + + resource_name = None + if not hasattr(data, "__name__"): + logger.info("Setting default resource name to `content` for adapted resource.") + resource_name = "content" + return cast(DltResource, make_resource(data, name=resource_name)) diff --git a/docs/examples/qdrant_zendesk/qdrant_zendesk.py b/docs/examples/qdrant_zendesk/qdrant_zendesk.py index 5416f2f2d0..9b6fbee150 100644 --- a/docs/examples/qdrant_zendesk/qdrant_zendesk.py +++ b/docs/examples/qdrant_zendesk/qdrant_zendesk.py @@ -165,14 +165,13 @@ def get_pages( dataset_name="zendesk_data", ) - # run the dlt pipeline and save info about the load process - load_info = pipeline.run( - # here we use a special function to tell Qdrant which fields to embed - qdrant_adapter( - zendesk_support(), # retrieve tickets data - embed=["subject", "description"], - ) - ) + # here we instantiate the source + source = zendesk_support() + # ...and apply special hints on the ticket resource to tell qdrant which fields to embed + qdrant_adapter(source.tickets_data, embed=["subject", "description"]) + + # run the dlt pipeline and print info about the load process + load_info = pipeline.run(source) print(load_info) @@ -189,7 +188,7 @@ def get_pages( # query Qdrant with prompt: getting tickets info close to "cancellation" response = qdrant_client.query( - "zendesk_data_content", # collection/dataset name with the 'content' suffix -> tickets content table + "zendesk_data_tickets_data", # tickets_data collection query_text="cancel subscription", # prompt to search limit=3, # limit the number of results to the nearest 3 embeddings ) diff --git a/tests/destinations/test_utils.py b/tests/destinations/test_utils.py new file mode 100644 index 0000000000..32fc286830 --- /dev/null +++ b/tests/destinations/test_utils.py @@ -0,0 +1,43 @@ +import dlt +import pytest + +from dlt.destinations.utils import get_resource_for_adapter +from dlt.extract import DltResource + + +def test_get_resource_for_adapter() -> None: + # test on pure data + data = [1, 2, 3] + adapted_resource = get_resource_for_adapter(data) + assert isinstance(adapted_resource, DltResource) + assert list(adapted_resource) == [1, 2, 3] + assert adapted_resource.name == "content" + + # test on resource + @dlt.resource(table_name="my_table") + def some_resource(): + yield [1, 2, 3] + + adapted_resource = get_resource_for_adapter(some_resource) + assert adapted_resource == some_resource + assert adapted_resource.name == "some_resource" + + # test on source with one resource + @dlt.source + def source(): + return [some_resource] + + adapted_resource = get_resource_for_adapter(source()) + assert adapted_resource.table_name == "my_table" + + # test on source with multiple resources + @dlt.resource(table_name="my_table") + def other_resource(): + yield [1, 2, 3] + + @dlt.source + def other_source(): + return [some_resource, other_resource] + + with pytest.raises(ValueError): + get_resource_for_adapter(other_source())