Commit f1633e5

Merge pull request #959 from dlt-hub/devel
dlt 0.4.4 release master merge
2 parents a8f3338 + 882b29b commit f1633e5

File tree

15 files changed: +4445 -3940 lines changed

dlt/common/libs/pydantic.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -140,7 +140,7 @@ def pydantic_to_table_schema_columns(
         # This case is for a single field schema/model
         # we need to generate snake_case field names
         # and return flattened field schemas
-        schema_hints = pydantic_to_table_schema_columns(field.annotation)
+        schema_hints = pydantic_to_table_schema_columns(inner_type)

         for field_name, hints in schema_hints.items():
             schema_key = snake_case_naming_convention.make_path(name, field_name)
```

dlt/common/typing.py

Lines changed: 41 additions & 5 deletions
```diff
@@ -25,7 +25,26 @@
     runtime_checkable,
     IO,
 )
-from typing_extensions import TypeAlias, ParamSpec, Concatenate, Annotated, get_args, get_origin
+
+from typing_extensions import (
+    Annotated,
+    Never,
+    ParamSpec,
+    TypeAlias,
+    Concatenate,
+    get_args,
+    get_origin,
+)
+
+try:
+    from types import UnionType  # type: ignore[attr-defined]
+except ImportError:
+    # Since new Union syntax was introduced in Python 3.10
+    # we need to substitute it here for older versions.
+    # it is defined as type(int | str) but for us having it
+    # as shown here should suffice because it is valid only
+    # in versions of Python>=3.10.
+    UnionType = Never

 from dlt.common.pendulum import timedelta, pendulum

@@ -103,18 +122,35 @@ def extract_type_if_modifier(t: Type[Any]) -> Type[Any]:


 def is_union_type(hint: Type[Any]) -> bool:
-    if get_origin(hint) is Union:
+    # We need to handle UnionType because with Python>=3.10
+    # new Optional syntax was introduced which treats Optionals
+    # as unions and probably internally there is no additional
+    # type hints to handle this edge case, see the examples below
+    # >>> type(str | int)
+    # <class 'types.UnionType'>
+    # >>> type(str | None)
+    # <class 'types.UnionType'>
+    # type(Union[int, str])
+    # <class 'typing._GenericAlias'>
+    origin = get_origin(hint)
+    if origin is Union or origin is UnionType:
         return True
+
     if hint := extract_type_if_modifier(hint):
         return is_union_type(hint)
+
     return False


 def is_optional_type(t: Type[Any]) -> bool:
-    if get_origin(t) is Union:
-        return type(None) in get_args(t)
+    origin = get_origin(t)
+    is_union = origin is Union or origin is UnionType
+    if is_union and type(None) in get_args(t):
+        return True
+
     if t := extract_type_if_modifier(t):
         return is_optional_type(t)
+
     return False


@@ -232,7 +268,7 @@ def get_generic_type_argument_from_instance(


 def copy_sig(
-    wrapper: Callable[TInputArgs, Any]
+    wrapper: Callable[TInputArgs, Any],
 ) -> Callable[[Callable[..., TReturnVal]], Callable[TInputArgs, TReturnVal]]:
     """Copies docstring and signature from wrapper to func but keeps the func return value type"""
```
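
A quick, hedged illustration of the runtime behaviour this hunk accounts for: `typing.Union` and the PEP 604 `|` operator produce different origins, which is why the patched helpers compare against both `Union` and `UnionType`. The snippet below assumes Python >= 3.10 (for the `X | Y` syntax) and imports the two helpers changed above from `dlt.common.typing`.

```python
from typing import Optional, Union, get_origin

from dlt.common.typing import is_optional_type, is_union_type

# The two spellings of a union have different runtime origins:
print(get_origin(Union[int, str]))  # typing.Union
print(get_origin(int | str))        # <class 'types.UnionType'> (Python >= 3.10)

# With the patched helpers both spellings are detected as unions/optionals.
assert is_union_type(Union[int, str]) and is_union_type(int | str)
assert is_optional_type(Optional[str]) and is_optional_type(str | None)
assert not is_optional_type(int | str)  # a plain union is not optional
```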

dlt/destinations/impl/athena/athena.py

Lines changed: 7 additions & 9 deletions
```diff
@@ -351,7 +351,9 @@ def _from_db_type(
         return self.type_mapper.from_db_type(hive_t, precision, scale)

     def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
-        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
+        return (
+            f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
+        )

     def _get_table_update_sql(
         self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
@@ -376,19 +378,15 @@ def _get_table_update_sql(
         # use qualified table names
         qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name)
         if is_iceberg and not generate_alter:
-            sql.append(
-                f"""CREATE TABLE {qualified_table_name}
+            sql.append(f"""CREATE TABLE {qualified_table_name}
                 ({columns})
                 LOCATION '{location}'
-                TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');"""
-            )
+                TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""")
         elif not generate_alter:
-            sql.append(
-                f"""CREATE EXTERNAL TABLE {qualified_table_name}
+            sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name}
                 ({columns})
                 STORED AS PARQUET
-                LOCATION '{location}';"""
-            )
+                LOCATION '{location}';""")
         # alter table to add new columns at the end
         else:
             sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""")
```

dlt/destinations/impl/bigquery/bigquery.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -252,9 +252,9 @@ def _get_table_update_sql(
         elif (c := partition_list[0])["data_type"] == "date":
             sql[0] = f"{sql[0]}\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}"
         elif (c := partition_list[0])["data_type"] == "timestamp":
-            sql[
-                0
-            ] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
+            sql[0] = (
+                f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
+            )
         # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp.
         # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning.
         # The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range.
@@ -272,7 +272,9 @@ def _get_table_update_sql(

     def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         name = self.capabilities.escape_identifier(c["name"])
-        return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
+        return (
+            f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
+        )

     def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
         schema_table: TTableSchemaColumns = {}
```
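
For orientation, the reformatted assignment appends a partition clause to the generated `CREATE TABLE` statement when the partition column is a timestamp. A rough sketch of the resulting string is below; the table DDL and the backtick-quoting `escape_identifier` stand-in are invented for illustration and are not dlt's actual capabilities object.

```python
def escape_identifier(name: str) -> str:
    # hypothetical stand-in for self.capabilities.escape_identifier
    return f"`{name}`"

sql = ["CREATE TABLE `my_dataset`.`events` (`created_at` TIMESTAMP)"]
partition_column = "created_at"

# mirrors the assignment in the hunk above for "timestamp" partition columns
sql[0] = f"{sql[0]}\nPARTITION BY DATE({escape_identifier(partition_column)})"

print(sql[0])
# CREATE TABLE `my_dataset`.`events` (`created_at` TIMESTAMP)
# PARTITION BY DATE(`created_at`)
```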

dlt/destinations/impl/databricks/databricks.py

Lines changed: 18 additions & 10 deletions
```diff
@@ -166,12 +166,14 @@ def __init__(
             else:
                 raise LoadJobTerminalException(
                     file_path,
-                    f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and azure buckets are supported",
+                    f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and"
+                    " azure buckets are supported",
                 )
         else:
             raise LoadJobTerminalException(
                 file_path,
-                "Cannot load from local file. Databricks does not support loading from local files. Configure staging with an s3 or azure storage bucket.",
+                "Cannot load from local file. Databricks does not support loading from local files."
+                " Configure staging with an s3 or azure storage bucket.",
             )

         # decide on source format, stage_file_path will either be a local file or a bucket path
@@ -181,27 +183,33 @@ def __init__(
             if not config.get("data_writer.disable_compression"):
                 raise LoadJobTerminalException(
                     file_path,
-                    "Databricks loader does not support gzip compressed JSON files. Please disable compression in the data writer configuration: https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression",
+                    "Databricks loader does not support gzip compressed JSON files. Please disable"
+                    " compression in the data writer configuration:"
+                    " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression",
                 )
             if table_schema_has_type(table, "decimal"):
                 raise LoadJobTerminalException(
                     file_path,
-                    "Databricks loader cannot load DECIMAL type columns from json files. Switch to parquet format to load decimals.",
+                    "Databricks loader cannot load DECIMAL type columns from json files. Switch to"
+                    " parquet format to load decimals.",
                 )
             if table_schema_has_type(table, "binary"):
                 raise LoadJobTerminalException(
                     file_path,
-                    "Databricks loader cannot load BINARY type columns from json files. Switch to parquet format to load byte values.",
+                    "Databricks loader cannot load BINARY type columns from json files. Switch to"
+                    " parquet format to load byte values.",
                 )
            if table_schema_has_type(table, "complex"):
                 raise LoadJobTerminalException(
                     file_path,
-                    "Databricks loader cannot load complex columns (lists and dicts) from json files. Switch to parquet format to load complex types.",
+                    "Databricks loader cannot load complex columns (lists and dicts) from json"
+                    " files. Switch to parquet format to load complex types.",
                 )
             if table_schema_has_type(table, "date"):
                 raise LoadJobTerminalException(
                     file_path,
-                    "Databricks loader cannot load DATE type columns from json files. Switch to parquet format to load dates.",
+                    "Databricks loader cannot load DATE type columns from json files. Switch to"
+                    " parquet format to load dates.",
                 )

             source_format = "JSON"
@@ -311,7 +319,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non

     def _get_storage_table_query_columns(self) -> List[str]:
         fields = super()._get_storage_table_query_columns()
-        fields[
-            1
-        ] = "full_data_type"  # Override because this is the only way to get data type with precision
+        fields[1] = (  # Override because this is the only way to get data type with precision
+            "full_data_type"
+        )
         return fields
```
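
The message rewrites in this file only re-wrap long string literals for line length; adjacent literals are concatenated by the Python compiler, so the raised error text is unchanged. A quick check using one of the messages from the hunks above:

```python
# the single-line literal removed by the diff
old = "Databricks loader cannot load DATE type columns from json files. Switch to parquet format to load dates."

# the wrapped form added by the diff: adjacent string literals join at compile time
new = (
    "Databricks loader cannot load DATE type columns from json files. Switch to"
    " parquet format to load dates."
)

assert old == new  # no '+' needed, the texts are identical
```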

dlt/destinations/impl/snowflake/snowflake.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -175,15 +175,13 @@ def __init__(
                 f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,'
                 " AUTO_COMPRESS = FALSE"
             )
-        client.execute_sql(
-            f"""COPY INTO {qualified_table_name}
+        client.execute_sql(f"""COPY INTO {qualified_table_name}
             {from_clause}
             {files_clause}
             {credentials_clause}
             FILE_FORMAT = {source_format}
             MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'
-            """
-        )
+            """)
         if stage_file_path and not keep_staged_files:
             client.execute_sql(f"REMOVE {stage_file_path}")
```

dlt/extract/incremental/__init__.py

Lines changed: 23 additions & 8 deletions
```diff
@@ -119,7 +119,7 @@ def __init__(
         self.start_value: Any = initial_value
         """Value of last_value at the beginning of current pipeline run"""
         self.resource_name: Optional[str] = None
-        self.primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key
+        self._primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key
         self.allow_external_schedulers = allow_external_schedulers

         self._cached_state: IncrementalColumnState = None
@@ -133,6 +133,18 @@ def __init__(

         self._transformers: Dict[str, IncrementalTransform] = {}

+    @property
+    def primary_key(self) -> Optional[TTableHintTemplate[TColumnNames]]:
+        return self._primary_key
+
+    @primary_key.setter
+    def primary_key(self, value: str) -> None:
+        # set key in incremental and data type transformers
+        self._primary_key = value
+        if self._transformers:
+            for transform in self._transformers.values():
+                transform.primary_key = value
+
     def _make_transforms(self) -> None:
         types = [("arrow", ArrowIncremental), ("json", JsonIncremental)]
         for dt, kls in types:
@@ -143,7 +155,7 @@ def _make_transforms(self) -> None:
                 self.end_value,
                 self._cached_state,
                 self.last_value_func,
-                self.primary_key,
+                self._primary_key,
             )

     @classmethod
@@ -163,7 +175,7 @@ def copy(self) -> "Incremental[TCursorValue]":
             self.cursor_path,
             initial_value=self.initial_value,
             last_value_func=self.last_value_func,
-            primary_key=self.primary_key,
+            primary_key=self._primary_key,
             end_value=self.end_value,
             allow_external_schedulers=self.allow_external_schedulers,
         )
@@ -178,7 +190,7 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue
        >>>
        >>> my_resource(updated=incremental(initial_value='2023-01-01', end_value='2023-02-01'))
        """
-        kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self.primary_key)
+        kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self._primary_key)
         for key, value in dict(
             other, last_value_func=other.last_value_func, primary_key=other.primary_key
         ).items():
@@ -395,7 +407,6 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
             return rows

         transformer = self._get_transformer(rows)
-        transformer.primary_key = self.primary_key

         if isinstance(rows, list):
             return [
@@ -476,7 +487,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
                 elif isinstance(p.default, Incremental):
                     new_incremental = p.default.copy()

-                if not new_incremental or new_incremental.is_partial():
+                if (not new_incremental or new_incremental.is_partial()) and not self._incremental:
                     if is_optional_type(p.annotation):
                         bound_args.arguments[p.name] = None  # Remove partial spec
                     return func(*bound_args.args, **bound_args.kwargs)
@@ -486,15 +497,16 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
                 )
                 # pass Generic information from annotation to new_incremental
                 if (
-                    not hasattr(new_incremental, "__orig_class__")
+                    new_incremental
+                    and not hasattr(new_incremental, "__orig_class__")
                     and p.annotation
                     and get_args(p.annotation)
                 ):
                     new_incremental.__orig_class__ = p.annotation  # type: ignore

                 # set the incremental only if not yet set or if it was passed explicitly
                 # NOTE: the _incremental may be also set by applying hints to the resource see `set_template` in `DltResource`
-                if p.name in bound_args.arguments or not self._incremental:
+                if (new_incremental and p.name in bound_args.arguments) or not self._incremental:
                     self._incremental = new_incremental
                     self._incremental.resolve()
                 # in case of transformers the bind will be called before this wrapper is set: because transformer is called for a first time late in the pipe
@@ -531,6 +543,9 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:
             return item
         if self._incremental.primary_key is None:
             self._incremental.primary_key = self.primary_key
+        elif self.primary_key is None:
+            # propagate from incremental
+            self.primary_key = self._incremental.primary_key
         return self._incremental(item, meta)
```
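
The substantive part of this change is turning `primary_key` into a property whose setter fans the value out to transformers that already exist, instead of re-assigning it on every `__call__`. A stripped-down sketch of that pattern (the class and attribute names below are illustrative stand-ins, not dlt's actual `Incremental`/`IncrementalTransform` types):

```python
from typing import Dict, Optional


class Transform:
    def __init__(self, primary_key: Optional[str] = None) -> None:
        self.primary_key = primary_key


class IncrementalLike:
    def __init__(self, primary_key: Optional[str] = None) -> None:
        self._primary_key = primary_key
        self._transformers: Dict[str, Transform] = {}

    @property
    def primary_key(self) -> Optional[str]:
        return self._primary_key

    @primary_key.setter
    def primary_key(self, value: Optional[str]) -> None:
        # keep transformers that were already created in sync with the new key
        self._primary_key = value
        for transform in self._transformers.values():
            transform.primary_key = value


inc = IncrementalLike()
inc._transformers["json"] = Transform(inc.primary_key)
inc.primary_key = "id"  # the setter propagates the key to the existing transformer
assert inc._transformers["json"].primary_key == "id"
```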

dlt/pipeline/pipeline.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1163,9 +1163,9 @@ def _set_context(self, is_active: bool) -> None:
             # set destination context on activation
             if self.destination:
                 # inject capabilities context
-                self._container[
-                    DestinationCapabilitiesContext
-                ] = self._get_destination_capabilities()
+                self._container[DestinationCapabilitiesContext] = (
+                    self._get_destination_capabilities()
+                )
         else:
             # remove destination context on deactivation
             if DestinationCapabilitiesContext in self._container:
```
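
Although this hunk is only a formatting change, the line it touches illustrates dlt's pattern of a dependency container indexed by the context class itself. A minimal, hypothetical sketch of such a type-keyed container is below; dlt's real `Container` and `DestinationCapabilitiesContext` carry more behaviour than this.

```python
from typing import Any, Dict, Type


class TypeKeyedContainer:
    """Toy container indexed by class objects rather than string keys."""

    def __init__(self) -> None:
        self._items: Dict[Type[Any], Any] = {}

    def __setitem__(self, key: Type[Any], value: Any) -> None:
        self._items[key] = value

    def __getitem__(self, key: Type[Any]) -> Any:
        return self._items[key]

    def __contains__(self, key: Type[Any]) -> bool:
        return key in self._items


class CapabilitiesContext:  # stand-in for DestinationCapabilitiesContext
    max_identifier_length = 1024


container = TypeKeyedContainer()
container[CapabilitiesContext] = CapabilitiesContext()
assert CapabilitiesContext in container
```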
