Add FlowM BlockWorld memory benchmark dataset support #1

Open

huskydoge wants to merge 1 commit into simchowitzlabpublic:main from huskydoge:codex/flowm-blockworld-dataset

Conversation


huskydoge commented May 3, 2026

Summary

This PR adds dataset support for FlowM's 3D Dynamic BlockWorld memory benchmark through nano-world-model's existing dataset/data-source interface.

Original code repo: https://github.com/hlillemark/flowm
Paper link: https://arxiv.org/abs/2601.01075

The integration is intentionally minimal:

  • adds a BlockWorldDataSource under wm_datasets.data_source.memory
  • registers dataset=memory/blockworld
  • exposes FlowM discrete actions as 5-D one-hot action vectors
  • loads RGB MP4 frames through the existing WorldModelDataset video/action path
  • documents the expected data layout and training command

Depth videos are left untouched for now; this PR only wires the RGB/action path needed by NanoWM training.

Dataset Layout

Expected layout:

${DATASET_DIR}/blockworld/
├── sunday_v2_training/
│   └── 0/
│       ├── 0000_rgb.mp4
│       ├── 0000_depth.mp4
│       └── 0000_actions.pt
└── sunday_v2_validation/

*_actions.pt may contain either integer action ids or already-vectorized [T, 5] actions.
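To illustrate that normalization rule, here is a minimal pure-Python sketch (the actual data source operates on torch tensors; `normalize_actions` is a hypothetical helper for illustration, not code from this PR):

```python
def normalize_actions(actions, action_dim=5):
    """Map FlowM actions to [T, action_dim] one-hot rows.

    `actions` is either a list of integer action ids or a list of
    already-vectorized length-`action_dim` rows, mirroring the two
    payload shapes accepted for *_actions.pt.
    """
    if actions and isinstance(actions[0], (list, tuple)):
        # Already vectorized: just coerce entries to floats.
        return [[float(v) for v in row] for row in actions]
    one_hot = []
    for action_id in actions:
        row = [0.0] * action_dim
        row[action_id] = 1.0
        one_hot.append(row)
    return one_hot

print(normalize_actions([0, 4]))
# → [[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]]
```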

Usage

python src/main.py dataset=memory/blockworld model=nanowm_b2

Validation

I validated the integration with synthetic BlockWorld-style episodes matching the expected RGB MP4 + action sidecar layout.

Dataset and config smoke

$ python -m unittest tests.test_blockworld_data_source
....
----------------------------------------------------------------------
Ran 4 tests in 0.813s

OK
$ python -m py_compile \
  src/wm_datasets/data_source/memory/blockworld_data_source.py \
  src/wm_datasets/data_source/factory.py \
  src/wm_datasets/data_source/__init__.py \
  src/wm_datasets/world_model_dataset.py
$ python -c "... compose dataset=memory/blockworld model=nanowm_b2 ..."
blockworld
./data/blockworld/sunday_v2_training
./data/blockworld/sunday_v2_validation
5

Model loss smoke

Built NanoWM with the BlockWorld dataset config and action_dim=5, then ran the diffusion training loss and backward pass.

model_arch NanoWM-S/8
action_dim 5
loss 0.9845133423805237
backward_ok True
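The same kind of loss/backward check can be reproduced with a toy stand-in; this is a generic sketch of the smoke test shape, not the actual NanoWM code (the real run builds NanoWM with action_dim=5 and uses the diffusion training loss):

```python
import torch

# Stand-in model: a linear layer is enough to illustrate the
# "loss is finite, backward populates gradients" assertions.
model = torch.nn.Linear(5, 5)
actions = torch.randn(4, 5)
loss = (model(actions) - actions).pow(2).mean()
loss.backward()

backward_ok = all(p.grad is not None for p in model.parameters())
print("loss", float(loss))
print("backward_ok", backward_ok)
```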

Full training smoke

Ran a one-step training smoke through the repo's normal src/main.py entrypoint with synthetic BlockWorld-style data.

Command shape:

python src/main.py \
  dataset=memory/blockworld \
  model=nanowm_s2 \
  model.num_frames=2 \
  model.image_size=32 \
  model.latent_size=8 \
  experiment.training.max_steps=1 \
  experiment.training.batch_size=1 \
  experiment.infra.num_workers=0 \
  experiment.infra.compile=false \
  experiment.infra.mixed_precision=false \
  experiment.evaluation.metrics.evaluate=false \
  experiment.evaluation.save_videos=false \
  wandb.enabled=false

Relevant output:

Loaded 2 BlockWorld episodes
  Action dim: 5, State dim: 0 (pure vision)
Split 'train': 2 trajectories (random sampling mode)

Loaded 1 BlockWorld episodes
  Action dim: 5, State dim: 0 (pure vision)
Creating slices (exhaustive mode, stride=1) from 1 trajectories...
Created 5 slices total
Split 'val': 5 slices

[Data] Datasets created
[Init] NanoWMTrainingModule: building model
[Init] VAE loaded, scaling_factor=0.18215, vae_precision=fp32
[Init] VAE sanity: clean under vae_precision=fp32
[Init] NanoWMTrainingModule initialized
[Data] DataLoaders created

(step=0000001/epoch=0000) Train Loss: 0.2491, Gradient Norm: 0.0000
val_loss=0.511
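The slice counts in the logs are consistent with an exhaustive sliding-window slicer; assuming it yields one clip per valid start index at the given stride (an inference from the output, not code from this PR), the count works out as:

```python
def count_slices(episode_len, clip_len, stride=1):
    # Exhaustive sliding-window count: one slice per valid start index.
    if episode_len < clip_len:
        return 0
    return (episode_len - clip_len) // stride + 1

# A 6-frame episode with num_frames=2 gives the 5 val slices seen above,
# and the 4-frame synthetic episodes in the unit tests give 3.
print(count_slices(6, 2), count_slices(4, 2))  # → 5 3
```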

Also ran:

$ git diff --check

No whitespace errors were reported.

Out-of-tree validation snippet

The following focused validation file was used to check the integration with synthetic BlockWorld-style episodes. It is not included in this PR.

Validation file used locally
"""
Tests for the FlowM BlockWorld memory benchmark data source.
"""

import sys
import types
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch

import torch
import imageio.v2 as imageio
import numpy as np

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))

if "decord" not in sys.modules:
    decord_stub = types.SimpleNamespace(
        VideoReader=object,
        bridge=types.SimpleNamespace(set_bridge=lambda *_args, **_kwargs: None),
    )
    sys.modules["decord"] = decord_stub
if "h5py" not in sys.modules:
    sys.modules["h5py"] = types.SimpleNamespace(File=object)

from wm_datasets import create_train_val_datasets
from wm_datasets.data_source import create_data_source


class BlockWorldDataSourceTest(unittest.TestCase):
    """Verify that FlowM BlockWorld episodes match the nano-world-model interface."""

    def test_integer_actions_are_exposed_as_one_hot_actions(self) -> None:
        """Load FlowM-style integer actions as 5-D one-hot action vectors."""
        with TemporaryDirectory() as tmpdir:
            self._write_episode(Path(tmpdir) / "sunday_v2_training" / "0", [0, 1, 2, 4])

            source = create_data_source("blockworld", data_path=tmpdir)
            trajectory = source.load_trajectory(0)

            self.assertEqual(source.get_num_trajectories(), 1)
            self.assertEqual(source.action_dim, 5)
            self.assertEqual(source.state_dim, 0)
            self.assertEqual(trajectory.seq_length, 4)
            self.assertEqual(tuple(trajectory.actions.shape), (4, 5))
            self.assertTrue(torch.equal(trajectory.actions[0], torch.tensor([1, 0, 0, 0, 0]).float()))
            self.assertTrue(torch.equal(trajectory.actions[3], torch.tensor([0, 0, 0, 0, 1]).float()))
            self.assertEqual(trajectory.meta["episode_id"], "sunday_v2_training/0/0000_rgb.mp4")

    def test_train_val_factory_accepts_split_paths(self) -> None:
        """Create WorldModelDataset train/val splits from FlowM split directories."""
        with TemporaryDirectory() as tmpdir:
            root = Path(tmpdir) / "blockworld"
            self._write_episode(root / "sunday_v2_training" / "0", [0, 1, 2, 4])
            self._write_episode(root / "sunday_v2_validation" / "0", [4, 2, 1, 0])

            train_dataset, val_dataset = create_train_val_datasets(
                dataset_name="blockworld",
                data_path_train=str(root / "sunday_v2_training"),
                data_path_val=str(root / "sunday_v2_validation"),
                num_frames=2,
                normalize_action=False,
                train_slice_mode="random",
                val_slice_mode="exhaustive",
                action_dim=5,
            )

            self.assertEqual(train_dataset.action_dim, 5)
            self.assertEqual(val_dataset.action_dim, 5)
            self.assertEqual(len(train_dataset), 1)
            self.assertEqual(len(val_dataset), 3)

    def test_world_model_dataset_loads_video_and_action_clip(self) -> None:
        """Decode RGB frames through WorldModelDataset.__getitem__."""
        with TemporaryDirectory() as tmpdir:
            root = Path(tmpdir) / "blockworld"
            self._write_episode(root / "sunday_v2_training" / "0", [0, 1, 2, 4], write_video=True)
            self._write_episode(root / "sunday_v2_validation" / "0", [4, 2, 1, 0], write_video=True)

            _, val_dataset = create_train_val_datasets(
                dataset_name="blockworld",
                data_path_train=str(root / "sunday_v2_training"),
                data_path_val=str(root / "sunday_v2_validation"),
                num_frames=2,
                image_size=(8, 8),
                normalize_action=False,
                normalize_pixel=True,
                train_slice_mode="random",
                val_slice_mode="exhaustive",
                action_dim=5,
            )

            sample = val_dataset[0]

            self.assertEqual(tuple(sample["video"].shape), (2, 3, 8, 8))
            self.assertEqual(tuple(sample["action"].shape), (2, 5))
            self.assertEqual(sample["video"].dtype, torch.float32)
            self.assertGreaterEqual(float(sample["video"].min()), -1.0)
            self.assertLessEqual(float(sample["video"].max()), 1.0)
            self.assertTrue(torch.equal(sample["action"][0], torch.tensor([0, 0, 0, 0, 1]).float()))

    def test_decord_torch_bridge_frames_are_supported(self) -> None:
        """Decode decord torch-bridge tensors without calling NDArray.asnumpy."""
        with TemporaryDirectory() as tmpdir:
            self._write_episode(Path(tmpdir) / "sunday_v2_training" / "0", [0, 1, 2])
            source = create_data_source("blockworld", data_path=tmpdir)

            class TorchBridgeVideoReader:
                """Minimal decord VideoReader stub that returns torch tensors."""

                def __init__(self, _video_path: str, ctx: object | None = None) -> None:
                    """Accept the same arguments used by BlockWorldDataSource."""
                    self._frames = torch.full((3, 4, 5, 3), 255, dtype=torch.uint8)

                def __len__(self) -> int:
                    """Return the synthetic video length."""
                    return self._frames.shape[0]

                def get_batch(self, frame_indices: list[int]) -> torch.Tensor:
                    """Return a torch tensor like decord's torch bridge."""
                    return self._frames[frame_indices]

            decord_stub = types.SimpleNamespace(
                VideoReader=TorchBridgeVideoReader,
                cpu=lambda _index: object(),
            )
            with patch.dict(sys.modules, {"decord": decord_stub}):
                frames = source.load_visual_frames(0, 0, 2)

            self.assertEqual(tuple(frames.shape), (2, 3, 4, 5))
            self.assertTrue(torch.allclose(frames, torch.ones_like(frames)))

    def _write_episode(
        self,
        episode_dir: Path,
        actions: list[int],
        write_video: bool = False,
    ) -> None:
        """Write a minimal FlowM-style episode directory for tests."""
        episode_dir.mkdir(parents=True)
        video_path = episode_dir / "0000_rgb.mp4"
        if write_video:
            frames = []
            for i in range(len(actions)):
                frame = np.full((16, 16, 3), i * 32, dtype=np.uint8)
                frame[:, :, 0] = 255 - i * 32
                frames.append(frame)
            imageio.mimwrite(video_path, frames, fps=4, codec="libx264")
        else:
            video_path.touch()
        torch.save({"actions": torch.tensor(actions)}, episode_dir / "0000_actions.pt")


if __name__ == "__main__":
    unittest.main()

Copilot AI review requested due to automatic review settings May 3, 2026 01:40

Copilot AI left a comment


Pull request overview

This PR adds first-class BlockWorld support to the dataset layer so the training/evaluation pipeline can consume FlowM 3D memory benchmark episodes through the existing WorldModelDataset abstractions.

Changes:

  • Adds a new BlockWorldDataSource for RGB MP4 episodes with .pt action sidecars and integrates it into the data source factory/export surface.
  • Adds a new Hydra dataset config group under src/configs/dataset/memory/ for BlockWorld train/validation paths and action sizing.
  • Updates repository docs and dataset references to include BlockWorld as a supported dataset family.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

Summary per file:

  • src/wm_datasets/world_model_dataset.py: threads BlockWorld-specific loader kwargs through dataset creation helpers.
  • src/wm_datasets/data_source/memory/blockworld_data_source.py: implements the new BlockWorld filesystem/video/action data source.
  • src/wm_datasets/data_source/memory/__init__.py: exports the new memory data source package surface.
  • src/wm_datasets/data_source/factory.py: registers blockworld in the data source factory.
  • src/wm_datasets/data_source/__init__.py: re-exports BlockWorldDataSource from the data source package.
  • src/wm_datasets/__init__.py: re-exports BlockWorldDataSource at the top-level dataset package.
  • src/wm_datasets/README.md: documents BlockWorld in the dataset summary and loader kwargs list.
  • src/configs/dataset/memory/blockworld.yaml: adds the concrete BlockWorld dataset config.
  • src/configs/dataset/memory/base.yaml: defines the shared schema/defaults for memory dataset configs.
  • docs/datasets/README.md: adds BlockWorld download, layout, config, and usage documentation.
  • docs/config_system.md: notes the new memory dataset config family in the config system docs.
  • docs/README.md: updates docs index entries to include BlockWorld.
  • README.md: updates top-level dataset documentation references to include BlockWorld.


Inline review excerpts (diff hunks the five generated comments attach to):

    """Load explicit video paths from a file list."""
    file_list_path = Path(file_list).expanduser()
    if not file_list_path.is_absolute():
        file_list_path = Path.cwd() / file_list_path

Comment on lines +110 to +112:

    stem = video_path.stem
    if stem.endswith("_rgb"):
        stem = stem[:-4]

    """Load and cache one episode's action sequence."""
    if index not in self._action_cache:
        _, _, action_path = self._episodes[index]
        payload = torch.load(action_path, map_location=torch.device("cpu"), weights_only=False)

    action_dim: 5

    spec:
      action_dim: 5

Comment on lines +79 to +83:

    action_path = self._action_path_from_video(video_path)
    if not action_path.exists():
        raise FileNotFoundError(
            f"Missing BlockWorld action file for {video_path}: expected {action_path}"
        )