dlt/common/runners/pool_runner.py (6 additions, 1 deletion)

@@ -8,6 +8,7 @@
from dlt.common import logger
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs.pluggable_run_context import PluggableRunContext
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.runtime import init
from dlt.common.runners.runnable import Runnable, TExecutor
from dlt.common.runners.configuration import PoolRunnerConfiguration
@@ -147,10 +148,14 @@ def create_pool(config: PoolRunnerConfiguration) -> Executor:
)
if start_method != "fork":
ctx = Container()[PluggableRunContext]
section_ctx = None
if ConfigSectionContext in Container():
section_ctx = Container()[ConfigSectionContext]

executor = ProcessPoolExecutor(
max_workers=config.workers,
initializer=init.restore_run_context,
initargs=(ctx.context,),
initargs=(ctx.context, section_ctx),
mp_context=multiprocessing.get_context(method=start_method),
)
else:
dlt/common/runtime/init.py (12 additions, 2 deletions)

@@ -1,4 +1,7 @@
from typing import Optional

from dlt.common.configuration.specs import RuntimeConfiguration
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.configuration.specs.pluggable_run_context import (
PluggableRunContext,
RunContextBase,
@@ -29,8 +32,12 @@ def initialize_runtime(logger_name: str, runtime_config: RuntimeConfiguration) -
start_telemetry(runtime_config)


def restore_run_context(run_context: RunContextBase) -> None:
"""Restores `run_context` by placing it into container and if `runtime_config` is present, initializes runtime
def restore_run_context(
run_context: RunContextBase, section_context: Optional[ConfigSectionContext] = None
) -> None:
"""Restores `run_context` and optionally `section_context` by placing them into container.
If `runtime_config` is present, initializes runtime.

Intended to be called by workers in a process pool.
"""
from dlt.common.configuration.container import Container
@@ -39,3 +46,6 @@ def restore_run_context(run_context: RunContextBase) -> None:
assert run_context.runtime_config is not None

Container()[PluggableRunContext] = PluggableRunContext(run_context)

if section_context:
Container()[ConfigSectionContext] = section_context
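For readers unfamiliar with the worker-init mechanism being extended here, a minimal standalone sketch (illustrative only, not dlt code; all names below are made up) of how `ProcessPoolExecutor` ships state to spawned workers via `initializer`/`initargs`, which is exactly where `restore_run_context` plugs in:

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, Optional

# hypothetical stand-in for dlt's Container and the two contexts
_CONTAINER: Dict[str, Any] = {}


def _restore(run_ctx: Any, section_ctx: Optional[Any] = None) -> None:
    """Runs once per worker before any task: re-register state that spawn does not inherit."""
    _CONTAINER["run_context"] = run_ctx
    if section_ctx is not None:
        _CONTAINER["section_context"] = section_ctx


def _work(_: int) -> Any:
    # tasks submitted to the pool now see the restored section context
    return _CONTAINER.get("section_context")


if __name__ == "__main__":
    run_ctx = {"name": "run"}
    section_ctx = {"sections": ("my_section",)}
    pool = ProcessPoolExecutor(
        max_workers=2,
        initializer=_restore,  # called in every worker process
        initargs=(run_ctx, section_ctx),  # pickled and shipped across the spawn boundary
        mp_context=multiprocessing.get_context("spawn"),
    )
    with pool:
        assert all(result == section_ctx for result in pool.map(_work, range(4)))
```

Since `initargs` is pickled for spawned workers, `ConfigSectionContext` has to be picklable for this to work; under `fork` the child inherits the parent's memory, so the contexts are visible either way.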
tests/common/runtime/test_config_section_context_restore.py (new file, 143 additions)

@@ -0,0 +1,143 @@
"""Test that ConfigSectionContext is properly restored in spawned worker processes."""
Review comment from @djudjuu (Contributor, author), Dec 10, 2025:
do we want this test? I wanted to have something that is closer to the actual change, whereas the other test is more of a regression test that fixes the exact issue


import os
from typing import Tuple, ClassVar

import pytest

from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import PluggableRunContext, ConfigSectionContext
from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec
from dlt.common.configuration.resolve import resolve_configuration
from dlt.common.runners.configuration import PoolRunnerConfiguration
from dlt.common.runners.pool_runner import create_pool


@configspec
class SectionedTestConfig(BaseConfiguration):
Review comment from a Collaborator:
makes sense to have those tests. pls. move them to test_runners.py. (tbh. worker init tests were missing altogether)

"""A test configuration that uses a specific section."""

test_value: str = "default"

__section__: ClassVar[str] = "test_section"


def _worker_resolve_config() -> Tuple[str, Tuple[str, ...]]:
"""Worker function that resolves a config value using ConfigSectionContext.
Returns:
Tuple of (resolved_value, sections_from_context)
"""
section_ctx = Container()[ConfigSectionContext]
config = resolve_configuration(SectionedTestConfig())

return config.test_value, section_ctx.sections


def test_config_section_context_restored_in_spawn_worker() -> None:
Review comment from a Collaborator:
actually sections should be visible both in case of spawn and fork. so test both (parametrize)
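A rough sketch of what that parametrization could look like (assumptions: `start_method` is simply passed through `PoolRunnerConfiguration`, and `fork` is skipped on platforms that do not offer it):

```python
import multiprocessing

import pytest

from dlt.common.runners.configuration import PoolRunnerConfiguration


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
def test_config_section_context_restored_in_worker(start_method: str) -> None:
    if start_method not in multiprocessing.get_all_start_methods():
        pytest.skip(f"start method {start_method!r} is not available on this platform")
    # ... same setup and assertions as the test below, only the pool config changes:
    config = PoolRunnerConfiguration(pool_type="process", workers=4, start_method=start_method)
```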

"""Test that ConfigSectionContext is properly restored when using spawn method.
This test verifies that ConfigSectionContext is correctly serialized and restored
in worker processes, allowing config resolution to use the correct sections.
"""
# Set up environment variable with section-specific value
os.environ["MY_SECTION__TEST_SECTION__TEST_VALUE"] = "sectioned_value"
os.environ["TEST_SECTION__TEST_VALUE"] = "non_sectioned_value" # Should not be used

# Set up ConfigSectionContext in main process
section_context = ConfigSectionContext(
pipeline_name=None,
sections=("my_section",),
)

# Store it in container
container = Container()
container[ConfigSectionContext] = section_context
Review comment from a Collaborator:
we are also missing the case when there's no section context in the container (do `del` on `container[ConfigSectionContext]`)
you can parametrize this test to check that
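A possible shape for folding that case in, shown only as a sketch (it assumes `del container[ConfigSectionContext]` behaves the way the suggestion implies):

```python
import pytest

from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import ConfigSectionContext


@pytest.mark.parametrize("with_section_context", [True, False])
def test_config_section_context_restored_in_spawn_worker(with_section_context: bool) -> None:
    container = Container()
    if with_section_context:
        container[ConfigSectionContext] = ConfigSectionContext(
            pipeline_name=None, sections=("my_section",)
        )
    elif ConfigSectionContext in container:
        # cover the path where no section context was ever injected
        del container[ConfigSectionContext]
    # ... create the pool as below; workers should see ("my_section",) in the first case
    # and fall back to default section resolution in the second
```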


# Create process pool with spawn method and multiple workers
# Using multiple workers ensures we're actually testing cross-process behavior
config = PoolRunnerConfiguration(
pool_type="process",
workers=4,
start_method="spawn",
)

with create_pool(config) as pool:
# Submit multiple tasks to ensure we're using worker processes
futures = [pool.submit(_worker_resolve_config) for _ in range(4)]
results = [f.result() for f in futures]

# All workers should have the same ConfigSectionContext
result_value, result_sections = results[0]

# Verify that ConfigSectionContext was restored correctly
assert result_sections == ("my_section",), (
f"Expected sections ('my_section',) but got {result_sections}. "
"ConfigSectionContext was not properly restored in worker process."
)

# Verify that config resolution used the correct sections
assert result_value == "sectioned_value", (
f"Expected 'sectioned_value' but got '{result_value}'. "
"Config resolution did not use the restored ConfigSectionContext sections."
)


def test_config_section_context_with_pipeline_name() -> None:
pipeline_name = "test_pipeline"
os.environ[f"{pipeline_name.upper()}__MY_SECTION__TEST_SECTION__TEST_VALUE"] = (
"pipeline_sectioned_value"
)
os.environ["MY_SECTION__TEST_SECTION__TEST_VALUE"] = "sectioned_value"

section_context = ConfigSectionContext(
pipeline_name=pipeline_name,
sections=("my_section",),
)

container = Container()
container[ConfigSectionContext] = section_context

config = PoolRunnerConfiguration(
pool_type="process",
workers=4,
start_method="spawn",
)

with create_pool(config) as pool:
futures = [pool.submit(_worker_resolve_config) for _ in range(4)]
results = [f.result() for f in futures]

# Verify all workers got the same context
for result_value, result_sections in results:
assert result_sections == ("my_section",)
# Should prefer pipeline-specific value
assert result_value == "pipeline_sectioned_value"


def test_config_section_context_empty_sections() -> None:
os.environ["TEST_SECTION__TEST_VALUE"] = "non_sectioned_value"

# ConfigSectionContext with empty sections
section_context = ConfigSectionContext(
pipeline_name=None,
sections=(),
)

container = Container()
container[ConfigSectionContext] = section_context

config = PoolRunnerConfiguration(
pool_type="process",
workers=4,
start_method="spawn",
)

with create_pool(config) as pool:
futures = [pool.submit(_worker_resolve_config) for _ in range(4)]
results = [f.result() for f in futures]

# Verify all workers got empty sections
for result_value, result_sections in results:
assert result_sections == (), "Empty sections should be preserved"
assert result_value == "non_sectioned_value", "Should use non-sectioned value"
tests/load/pipeline/test_pipelines.py (54 additions)

@@ -568,6 +568,60 @@ def test_pipeline_data_writer_compression(
assert_table_column(p, "data", data, info=info)


@pytest.mark.parametrize(
"destination_config",
destinations_configs(
default_sql_configs=True,
all_buckets_filesystem_configs=True,
),
ids=lambda x: x.name,
)
def test_normalize_compression_with_spawn_workers(
Review comment from a Collaborator:
cool test! but running for all destinations is an overkill. pls. move it to tests/pipeline, it does need load step.

destination_config: DestinationTestConfiguration,
) -> None:
"""Disabling compression should work with multiple workers and spawn method,
because ConfigSectionContext is restored in worker processes.
"""
# Set compression disabled via normalize section
workers = 4
os.environ["NORMALIZE__DATA_WRITER__DISABLE_COMPRESSION"] = "true"
os.environ["NORMALIZE__WORKERS"] = str(workers)
os.environ["NORMALIZE__START_METHOD"] = "spawn"

data = ["a", "b", "c", "d", "e"]
dataset_name = "compression_spawn_test_" + uniq_id()

p = destination_config.setup_pipeline("compression_spawn_test", dataset_name=dataset_name)
p.extract(
dlt.resource(data, name="data"),
table_format=destination_config.table_format,
loader_file_format=destination_config.file_format,
)

# Normalize with multiple workers and spawn method
p.normalize(workers=workers)

# Check that normalized files are not compressed
load_storage = p._get_load_storage()
normalized_packages = load_storage.list_normalized_packages()
Review comment from a Collaborator:
this method is already on the pipeline
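Presumably this points at the pipeline-level helper rather than reaching through `_get_load_storage()`; the method name below is an assumption and worth checking against the Pipeline API:

```python
# assumed pipeline-level equivalent of load_storage.list_normalized_packages()
normalized_packages = p.list_normalized_load_packages()
```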

assert len(normalized_packages) > 0, "Should have at least one normalized package"

for load_id in normalized_packages:
# Get all job files from the normalized package
job_files = load_storage.normalized_packages.list_new_jobs(load_id)
assert len(job_files) > 0, f"Should have at least one job file in package {load_id}"

for job_file_name in job_files:
file_path = load_storage.normalized_packages.storage.make_full_path(job_file_name)
# If compression is disabled, file should NOT be gzipped
with pytest.raises(gzip.BadGzipFile):
with gzip.open(file_path, "rb") as f:
f.read()

info = p.load()
assert_table_column(p, "data", data, info=info)


@pytest.mark.essential
@pytest.mark.parametrize(
"destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name