7 changes: 6 additions & 1 deletion dlt/common/runners/pool_runner.py
@@ -8,6 +8,7 @@
from dlt.common import logger
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs.pluggable_run_context import PluggableRunContext
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.runtime import init
from dlt.common.runners.runnable import Runnable, TExecutor
from dlt.common.runners.configuration import PoolRunnerConfiguration
@@ -147,10 +148,14 @@ def create_pool(config: PoolRunnerConfiguration) -> Executor:
            )
        if start_method != "fork":
            ctx = Container()[PluggableRunContext]
            section_ctx = None
            if ConfigSectionContext in Container():
                section_ctx = Container()[ConfigSectionContext]

            executor = ProcessPoolExecutor(
                max_workers=config.workers,
                initializer=init.restore_run_context,
                initargs=(ctx.context,),
                initargs=(ctx.context, section_ctx),
                mp_context=multiprocessing.get_context(method=start_method),
            )
        else:
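The branch above only matters for non-`fork` start methods: a `spawn` (or `forkserver`) worker boots a fresh interpreter and does not inherit the parent's in-memory container state, so any captured contexts must be pickled through `initargs`. A minimal stdlib-only sketch of that failure mode and the initializer fix (the names `_STATE`, `_init_worker`, and `_read_sections` are hypothetical toys, not dlt APIs):

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

_STATE: dict = {}  # stands in for dlt's Container-held contexts


def _init_worker(state: dict) -> None:
    # Runs once per worker process: restores the parent's snapshot.
    _STATE.update(state)


def _read_sections() -> str:
    # Without the initializer, a "spawn" worker sees a fresh, empty _STATE.
    return _STATE.get("sections", "<missing>")


if __name__ == "__main__":
    _STATE["sections"] = "my_section"
    with ProcessPoolExecutor(
        max_workers=2,
        initializer=_init_worker,
        initargs=(_STATE,),
        mp_context=multiprocessing.get_context("spawn"),
    ) as pool:
        print(pool.submit(_read_sections).result())  # prints: my_section
```

Without `initializer`/`initargs`, the same submit would print `<missing>` under `spawn`, which is exactly the state loss this PR fixes for `ConfigSectionContext`.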
14 changes: 12 additions & 2 deletions dlt/common/runtime/init.py
@@ -1,4 +1,7 @@
from typing import Optional

from dlt.common.configuration.specs import RuntimeConfiguration
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.configuration.specs.pluggable_run_context import (
PluggableRunContext,
RunContextBase,
@@ -29,8 +32,12 @@ def initialize_runtime(logger_name: str, runtime_config: RuntimeConfiguration) -
    start_telemetry(runtime_config)


def restore_run_context(run_context: RunContextBase) -> None:
    """Restores `run_context` by placing it into container and if `runtime_config` is present, initializes runtime
def restore_run_context(
    run_context: RunContextBase, section_context: Optional[ConfigSectionContext] = None
) -> None:
    """Restores `run_context` and optionally `section_context` by placing them into container.
    If `runtime_config` is present, initializes runtime.

    Intended to be called by workers in a process pool.
    """
    from dlt.common.configuration.container import Container
@@ -39,3 +46,6 @@ def restore_run_context(run_context: RunContextBase) -> None:
    assert run_context.runtime_config is not None

    Container()[PluggableRunContext] = PluggableRunContext(run_context)

    if section_context:
        Container()[ConfigSectionContext] = section_context
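Taken together with the `create_pool` change above, the full wiring looks like this. This is a sketch that mirrors the diff rather than new API; it assumes a run context is already present in the container, as it is inside a running pipeline:

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import ConfigSectionContext
from dlt.common.configuration.specs.pluggable_run_context import PluggableRunContext
from dlt.common.runtime import init

# snapshot both contexts in the parent process
run_ctx = Container()[PluggableRunContext]
section_ctx = (
    Container()[ConfigSectionContext] if ConfigSectionContext in Container() else None
)

# each spawned worker runs restore_run_context(run_ctx.context, section_ctx)
# before picking up any task, re-populating its own Container
pool = ProcessPoolExecutor(
    max_workers=4,
    initializer=init.restore_run_context,
    initargs=(run_ctx.context, section_ctx),
    mp_context=multiprocessing.get_context("spawn"),
)
```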
100 changes: 98 additions & 2 deletions tests/common/runners/test_runners.py
@@ -1,12 +1,15 @@
import os
import pytest
import sys
import time
import multiprocessing
from typing import Type
from typing import ClassVar, Tuple, Type

from dlt.common.runtime import signals
from dlt.common.configuration import resolve_configuration, configspec
from dlt.common.configuration.specs import RuntimeConfiguration
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import ConfigSectionContext, RuntimeConfiguration
from dlt.common.configuration.specs.base_configuration import BaseConfiguration
from dlt.common.exceptions import DltException, SignalReceivedException
from dlt.common.runners import pool_runner as runner
from dlt.common.runners.configuration import PoolRunnerConfiguration, TPoolType
@@ -38,6 +41,27 @@ class ThreadPoolConfiguration(ModPoolRunnerConfiguration):
    pool_type: TPoolType = "thread"


@configspec
class SectionedTestConfig(BaseConfiguration):
    """A test configuration that uses a specific section."""

    test_value: str = "default"

    __section__: ClassVar[str] = "test_section"


def _worker_resolve_config() -> Tuple[str, Tuple[str, ...]]:
    """Worker function that resolves a config value using ConfigSectionContext.

    Returns:
        Tuple of (resolved_value, sections_from_context)
    """
    section_ctx = Container()[ConfigSectionContext]
    config = resolve_configuration(SectionedTestConfig())

    return config.test_value, section_ctx.sections


def configure(C: Type[PoolRunnerConfiguration]) -> PoolRunnerConfiguration:
    default = C()
    return resolve_configuration(default)
@@ -233,3 +257,75 @@ def test_use_null_executor_on_non_threading_platform(monkeypatch) -> None:
    config.pool_type = None
    pool = runner.create_pool(config)
    assert isinstance(pool, runner.NullExecutor)


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@pytest.mark.parametrize(
"use_section_context",
[True, False],
ids=lambda x: "with_section_context" if x else "without_section_context",
)
def test_config_section_context_restored_in_worker(
start_method: str, use_section_context: bool
) -> None:
"""Test that ConfigSectionContext is properly restored in worker processes.

This test verifies that ConfigSectionContext is correctly serialized and restored
in worker processes, allowing config resolution to use the correct sections.
When no ConfigSectionContext is set, workers should use the default empty sections.
"""
# Set up environment variables with section-specific values
os.environ["MY_SECTION__TEST_SECTION__TEST_VALUE"] = "sectioned_value"
os.environ["TEST_SECTION__TEST_VALUE"] = "non_sectioned_value"

container = Container()

if use_section_context:
# Set up ConfigSectionContext in main process
section_context = ConfigSectionContext(
pipeline_name=None,
sections=("my_section",),
)
container[ConfigSectionContext] = section_context
elif ConfigSectionContext in container:
# Ensure no ConfigSectionContext is in container
del container[ConfigSectionContext]

# Create process pool with multiple workers
# Using multiple workers ensures we're actually testing cross-process behavior
config = PoolRunnerConfiguration(
pool_type="process",
workers=4,
start_method=start_method,
)

with runner.create_pool(config) as pool:
# Submit multiple tasks to ensure we're using worker processes
futures = [pool.submit(_worker_resolve_config) for _ in range(4)]
results = [f.result() for f in futures]

# All workers should have the same ConfigSectionContext
result_value, result_sections = results[0]

if use_section_context:
# Verify that ConfigSectionContext was restored correctly
assert result_sections == ("my_section",), (
f"Expected sections ('my_section',) but got {result_sections}. "
"ConfigSectionContext was not properly restored in worker process."
)
# Verify that config resolution used the correct sections
assert result_value == "sectioned_value", (
f"Expected 'sectioned_value' but got '{result_value}'. "
"Config resolution did not use the restored ConfigSectionContext sections."
)
else:
# Without section context, should use default empty sections
assert result_sections == (), (
f"Expected empty sections () but got {result_sections}. "
"ConfigSectionContext should have default empty sections when not set."
)
# Verify that config resolution used the non-sectioned value
assert result_value == "non_sectioned_value", (
f"Expected 'non_sectioned_value' but got '{result_value}'. "
"Config resolution should use non-sectioned value when no ConfigSectionContext is set."
)
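For reference, the env keys in this test exercise dlt's layered section lookup: with `sections=("my_section",)` injected, resolving `SectionedTestConfig` (whose own section is `test_section`) first tries the more specific `MY_SECTION__TEST_SECTION__TEST_VALUE` key before falling back to `TEST_SECTION__TEST_VALUE`. A single-process sketch of the same lookup, assuming `Container.injectable_context` (the context manager dlt uses to scope injected contexts) and a hypothetical `SectionedDemoConfig`:

```python
import os
from typing import ClassVar

from dlt.common.configuration import configspec, resolve_configuration
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import ConfigSectionContext
from dlt.common.configuration.specs.base_configuration import BaseConfiguration


@configspec
class SectionedDemoConfig(BaseConfiguration):
    test_value: str = "default"
    __section__: ClassVar[str] = "test_section"


os.environ["MY_SECTION__TEST_SECTION__TEST_VALUE"] = "sectioned_value"
os.environ["TEST_SECTION__TEST_VALUE"] = "non_sectioned_value"

# without injected sections the plain TEST_SECTION__* key wins
assert resolve_configuration(SectionedDemoConfig()).test_value == "non_sectioned_value"

# with sections=("my_section",) the MY_SECTION__TEST_SECTION__* key wins
with Container().injectable_context(ConfigSectionContext(sections=("my_section",))):
    assert resolve_configuration(SectionedDemoConfig()).test_value == "sectioned_value"
```

The test above runs the second lookup inside pool workers, which only succeeds when the injected context survives the trip across the process boundary.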
46 changes: 46 additions & 0 deletions tests/pipeline/test_parallelism.py
@@ -1,14 +1,20 @@
"""
Actual parallelism test with the help of custom destination
"""
import gzip
import os

import pytest
import dlt
import time
from typing import Dict, Tuple

from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema
from dlt.common.destination.capabilities import TLoaderParallelismStrategy
from tests.pipeline.utils import (
    assert_table_column,
)


def run_pipeline(
@@ -113,3 +119,43 @@ def test_loading_strategy() -> None:
"t2": 1,
"t3": 1,
}


def test_normalize_compression_with_spawn_workers() -> None:
    """Disabling compression should work with multiple workers and the spawn method,
    because ConfigSectionContext is restored in worker processes.
    """
    # Disable compression via the normalize section
    workers = 4
    os.environ["NORMALIZE__DATA_WRITER__DISABLE_COMPRESSION"] = "true"
    os.environ["NORMALIZE__WORKERS"] = str(workers)
    os.environ["NORMALIZE__START_METHOD"] = "spawn"

    data = ["a", "b", "c", "d", "e"]
    dataset_name = "compression_spawn_test_"

    p = dlt.pipeline("compression_spawn_test", dataset_name=dataset_name, destination="duckdb")
    p.extract(dlt.resource(data, name="data"))

    # Normalize with multiple workers and the spawn method
    p.normalize(workers=workers)

    # Check that normalized files are not compressed
    normalized_packages = p.list_normalized_load_packages()
    assert len(normalized_packages) > 0, "Should have at least one normalized package"

    job_storage = p._get_load_storage()
    for load_id in normalized_packages:
        # Get all job files from the normalized package
        job_files = job_storage.normalized_packages.list_new_jobs(load_id)
        assert len(job_files) > 0, f"Should have at least one job file in package {load_id}"

        for job_file_name in job_files:
            file_path = job_storage.normalized_packages.storage.make_full_path(job_file_name)
            # If compression is disabled, the file should NOT be gzipped
            with pytest.raises(gzip.BadGzipFile):
                with gzip.open(file_path, "rb") as f:
                    f.read()

    info = p.load()
    assert_table_column(p, "data", data, info=info)