Skip to content

Commit 89b8da5

Browse files
authored
fixes bug, where configs where not injected for async functions (#1241)
1 parent 77e2499 commit 89b8da5

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

dlt/common/reflection/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import ast
22
import inspect
33
import astunparse
4-
from typing import Any, Dict, List, Optional, Sequence, Tuple
4+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
55

66
from dlt.common.typing import AnyFun
77

88

9-
def get_literal_defaults(node: ast.FunctionDef) -> Dict[str, str]:
9+
def get_literal_defaults(node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> Dict[str, str]:
1010
"""Extract defaults from function definition node literally, as pieces of source code"""
1111
defaults: List[ast.expr] = []
1212
if node.args.defaults:
@@ -30,12 +30,12 @@ def get_literal_defaults(node: ast.FunctionDef) -> Dict[str, str]:
3030
return literal_defaults
3131

3232

33-
def get_func_def_node(f: AnyFun) -> ast.FunctionDef:
33+
def get_func_def_node(f: AnyFun) -> Union[ast.FunctionDef, ast.AsyncFunctionDef]:
3434
"""Finds the function definition node for function f by parsing the source code of the f's module"""
3535
source, lineno = inspect.findsource(inspect.unwrap(f))
3636

3737
for node in ast.walk(ast.parse("".join(source))):
38-
if isinstance(node, ast.FunctionDef):
38+
if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef):
3939
f_lineno = node.lineno - 1
4040
# get line number of first decorator
4141
if node.decorator_list:

tests/common/reflection/test_reflect_spec.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,23 @@ def _f_4(str_str=300, p_def: bool = True):
358358
assert SPEC.get_resolvable_fields()["str_str"] == int
359359
# default
360360
assert fields["str_str"] == 300
361+
362+
363+
def test_reflect_async_function() -> None:
364+
async def _f_1_as(str_str: str = dlt.config.value, blah: bool = dlt.config.value):
365+
import asyncio
366+
367+
await asyncio.sleep(1)
368+
369+
SPEC_AS, fields_as = spec_from_signature(_f_1_as, inspect.signature(_f_1_as), False)
370+
371+
def _f_1(str_str: str = dlt.config.value, blah: bool = dlt.config.value):
372+
pass
373+
374+
SPEC, fields = spec_from_signature(_f_1, inspect.signature(_f_1), False)
375+
376+
# discovered fields are the same for sync and async functions
377+
assert fields
378+
assert fields == fields_as
379+
assert len(SPEC.get_resolvable_fields()) == len(fields) == 2
380+
assert SPEC.get_resolvable_fields() == SPEC_AS.get_resolvable_fields()

0 commit comments

Comments
 (0)