[Data] BlockMetadata.input_files causes excessive object store memory usage when passing datasets to Ray Train workers #61923

@lee1258561

Description

What happened + What you expected to happen

When passing Ray Datasets (created from millions of parquet files) to TorchTrainer, we observe ~12.87 GB of pinned objects per worker in the Ray object store. This causes excessive object spilling to disk and memory pressure.

Expected: Dataset handles should be lightweight references that don't bloat when serialized to workers.

Actual: BlockMetadata.input_files stores all source file paths, and these are serialized with the dataset when passed to training workers.

Root Cause Analysis

The Data Flow

read_parquet(paths)  # millions of S3 paths
    → FileBasedDatasource.get_read_tasks()
        → DefaultFileMetadataProvider._get_block_metadata()
            → BlockMetadata(input_files=paths)  # ALL paths stored here
    → DatasetStats(metadata={"Read": [read_task.metadata, ...]})
        → ExecutionPlan(stats)  # stored as _in_stats
            → Dataset._plan._in_stats.metadata  # BLOATED

The Problematic Change

Commit: 384f46cbb80 (June 12, 2024)
PR: #45860 - "Remove in_blocks parameter of ExecutionPlan"

This commit (part of the LazyBlockList removal effort) changed how DatasetStats is created:

Before:

block_list = LazyBlockList(read_tasks, ...)
return Dataset(plan=ExecutionPlan(block_list, block_list.stats(), ...))

After:

stats = DatasetStats(
    metadata={"Read": [read_task.metadata for read_task in read_tasks]},
    parent=None,
)
return Dataset(plan=ExecutionPlan(stats, ...))

The full BlockMetadata objects (including input_files with millions of file paths) are now directly stored in DatasetStats.metadata.
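To illustrate why this matters for serialization, here is a minimal sketch. It uses a plain dict as a stand-in for BlockMetadata and the standard pickle module rather than Ray's own serializer, so the absolute numbers will not match Ray. The helper name pickled_size and the example paths are hypothetical. It only shows that the serialized size of a metadata object grows with the number of paths it carries.

```python
import pickle


def pickled_size(num_paths: int) -> int:
    """Return the pickled size in bytes of a metadata stand-in with num_paths paths."""
    # Hypothetical paths; every string is distinct, so pickle cannot deduplicate them.
    paths = [f"s3://bucket/prefix/part-{i:08d}.parquet" for i in range(num_paths)]
    # Plain dict standing in for BlockMetadata (not Ray's actual class).
    meta = {"num_rows": None, "size_bytes": None, "input_files": paths}
    return len(pickle.dumps(meta))


if __name__ == "__main__":
    for n in (0, 1_000, 10_000):
        print(n, pickled_size(n))
```

The sizes printed for larger n are dominated by the path strings themselves, which is the same effect the issue describes for DatasetStats.metadata.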

Where input_files is populated

File: python/ray/data/datasource/file_meta_provider.py:156-161

def _get_block_metadata(self, paths: List[str], ...) -> BlockMetadata:
    return BlockMetadata(
        num_rows=num_rows,
        size_bytes=...,
        input_files=paths,  # <-- All source file paths
        exec_stats=None,
    )

Impact

When using Ray Data + Ray Train with large parquet datasets:

  • BlockMetadata.input_files contains millions of S3/file paths
  • The bloated metadata is stored in dataset._plan._in_stats.metadata

The critical issue: the Dataset is serialized separately for every worker

When TorchTrainer initializes workers, it serializes the Dataset object and passes it as an argument to each RayTrainWorker.init_train_context() call:

TorchTrainer.fit()
    → For each of N workers:
        → RayTrainWorker.init_train_context.remote(train_run_context, dataset_shard_provider)
            → train_run_context contains: datasets dict (serialized)
            → dataset_shard_provider contains: DataIterator with base dataset (serialized again)

This means:

  • N workers × 2 copies per worker (TrainRunContext + DatasetShardProvider) = 2N serialized copies
  • Each copy includes the full _in_stats.metadata with millions of file paths
  • With 8 workers: 16 copies of the bloated metadata pinned in object store
  • Result: ~12.87 GB of pinned objects per worker observed via ray memory

The serialization overhead scales linearly with both:

  1. Number of input files (millions of paths in input_files)
  2. Number of training workers (each receives full serialized dataset)
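The scaling described above can be sketched as a back-of-the-envelope calculation. The function name and the input values below are hypothetical and are not measurements from this report. The estimate counts only the raw path bytes and ignores serialization overhead and every other field, so treat it as a rough lower bound rather than a reproduction of the 12.87 GB figure.

```python
def estimate_serialized_bytes(
    num_files: int,
    avg_path_bytes: int,
    num_workers: int,
    copies_per_worker: int = 2,
) -> int:
    """Rough lower bound on path bytes across all serialized dataset copies."""
    # One serialized copy carries every input path once.
    per_copy = num_files * avg_path_bytes
    # Each worker receives copies_per_worker copies
    # (TrainRunContext + DatasetShardProvider in the issue's analysis).
    return per_copy * copies_per_worker * num_workers


if __name__ == "__main__":
    # Hypothetical inputs: 5 million files, 100-byte paths, 8 workers.
    total = estimate_serialized_bytes(5_000_000, 100, 8)
    print(f"{total / 1e9:.1f} GB")  # prints "8.0 GB"
```

Doubling either the file count or the worker count doubles the result, which is the linear scaling in both dimensions.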

Proposed Solutions

Option 1: Don't populate input_files at all (Ray Data scope)

Pros:

  • Simplest fix
  • Benefits all Ray Data use cases
  • input_files appears to be only for debugging/stats, not execution

Cons:

  • Breaking change for users who rely on dataset.input_files() API
  • Loss of provenance tracking

Implementation:

# file_meta_provider.py
return BlockMetadata(
    num_rows=num_rows,
    size_bytes=...,
    input_files=None,  # or []
    exec_stats=None,
)

Option 2: Clear metadata when dataset is passed to Ray Train (Ray Train scope)

Pros:

  • No impact on Ray Data standalone usage
  • Preserves input_files for debugging in non-Train scenarios

Cons:

  • Train-specific workaround, doesn't help other serialization scenarios
  • Requires Train to know about Data internals

Implementation (what we did as a workaround):

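# In TorchTrainer or DataConfig, before the datasets are handed to workers.
# Note: this relies on private Ray Data attributes (_plan, _in_stats).
def _clear_dataset_stats_metadata(datasets):
    """Drop per-block read metadata (including input_files) from each dataset."""
    for dataset in datasets.values():
        plan = getattr(dataset, "_plan", None)
        if plan is None:
            continue
        in_stats = getattr(plan, "_in_stats", None)
        if in_stats is not None and hasattr(in_stats, "metadata"):
            in_stats.metadata = {}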

Option 3: Store input_files as object reference (Ray Data scope)

Similar to how FileBasedDatasource stores paths, use ray.put() to store the input_files list in the object store and keep only a reference in BlockMetadata.

Pros:

  • Preserves full input_files information
  • Only one copy in object store, referenced by all workers
  • Consistent with existing patterns in Ray Data

Cons:

  • More complex implementation
  • Requires careful lifecycle management of the object reference
  • May need schema changes to BlockMetadata

Implementation sketch:

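# When creating BlockMetadata.
# `threshold` is a tunable cutoff (not defined here); shorter path lists stay inline.
input_files_ref = ray.put(paths) if len(paths) > threshold else None

return BlockMetadata(
    num_rows=num_rows,
    size_bytes=...,
    input_files=[] if input_files_ref is not None else paths,
    input_files_ref=input_files_ref,  # New field, would require a BlockMetadata schema change
    exec_stats=None,
)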
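The indirection in this option can be sketched without Ray. The example below replaces the object store with an in-process dict, so put and resolve_input_files are hypothetical stand-ins for ray.put() and ray.get(), and the metadata is a plain dict rather than BlockMetadata. It only demonstrates the idea: the large list is stored once, and the metadata that gets serialized carries a small handle.

```python
import pickle
import uuid
from typing import Dict, List, Optional

# In-process stand-in for the object store (an assumption for illustration only).
_STORE: Dict[str, List[str]] = {}


def put(paths: List[str]) -> str:
    """Store the list once and return a small handle (stand-in for ray.put)."""
    handle = uuid.uuid4().hex
    _STORE[handle] = paths
    return handle


def make_metadata(paths: List[str], threshold: int) -> dict:
    """Keep short lists inline; replace long lists with a handle."""
    ref: Optional[str] = put(paths) if len(paths) > threshold else None
    return {
        "input_files": [] if ref is not None else paths,
        "input_files_ref": ref,
    }


def resolve_input_files(meta: dict) -> List[str]:
    """Return the full path list, following the handle when one is present."""
    ref = meta["input_files_ref"]
    return _STORE[ref] if ref is not None else meta["input_files"]


if __name__ == "__main__":
    paths = [f"s3://bucket/part-{i}.parquet" for i in range(5000)]
    meta = make_metadata(paths, threshold=100)
    print(len(pickle.dumps(meta)), len(pickle.dumps(paths)))
    print(resolve_input_files(meta) == paths)  # prints True
```

In a real implementation the handle would be a Ray object reference, and the lifecycle concern listed in the cons applies: the stored list has to stay alive for as long as any metadata still points at it.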

Recommendation

Option 1 (don't populate input_files) is the cleanest solution if the field is not critical for execution. The input_files() API could be deprecated or made to return an empty list with a warning.
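One possible shape for that deprecation path is sketched below. DatasetLike is a hypothetical stand-in and not Ray's Dataset class, and the warning text is illustrative. The sketch only shows the pattern of keeping the method callable while signalling that it no longer returns file paths.

```python
import warnings
from typing import List


class DatasetLike:
    """Hypothetical stand-in for a dataset class; not Ray's Dataset."""

    def input_files(self) -> List[str]:
        """Return an empty list and warn that source files are no longer tracked."""
        warnings.warn(
            "input_files() no longer tracks source files and returns an empty list.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        return []


if __name__ == "__main__":
    warnings.simplefilter("always")
    print(DatasetLike().input_files())  # prints []
```

Existing callers keep working but receive an empty result, so code that depends on the returned paths would still need to be updated.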

If backward compatibility is required, Option 3 (object reference) provides a good balance between preserving information and avoiding serialization bloat.

Versions / Dependencies

  • Ray version: 2.52.1
  • Python version: 3.12
  • OS: Ubuntu 24.04

Reproduction script

import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

ray.init()

# Read from millions of parquet files (e.g., a large production dataset)
# The issue manifests when input_files contains millions of S3 paths
ds = ray.data.read_parquet("s3://bucket/path/with/millions/of/files/")

def train_fn(config):
    pass

trainer = TorchTrainer(
    train_fn,
    datasets={"train": ds},
    scaling_config=ScalingConfig(num_workers=8),
)

# Check object store before fit()
# ray memory --stats-only

trainer.fit()

Run ray memory to observe large pinned objects from RayTrainWorker.init_train_context.

Issue Severity

None

Labels

bug, data, train, triage

Status

Done