4 changes: 2 additions & 2 deletions README.md
@@ -48,7 +48,7 @@ mkdir -p pretrained_models/i3d && curl -L \
-o pretrained_models/i3d/i3d_torchscript.pt
```

For dataset downloads (DINO-WM, RT-1, CSGO), see [docs/datasets/README.md](docs/datasets/README.md).
For dataset downloads (DINO-WM, RT-1, CSGO, BlockWorld), see [docs/datasets/README.md](docs/datasets/README.md).

## 🥷 Train your first model

@@ -139,7 +139,7 @@ NanoWM rollouts can be used directly for downstream applications, including long
- **[docs/config_system.md](docs/config_system.md)** — Hydra config layout, overrides, environment variables
- **[docs/training.md](docs/training.md)** — training workflow, design choices, ablation tables, all checkpoints
- **[docs/evaluation.md](docs/evaluation.md)** — evaluation workflow, metric definitions, full result tables
- **[docs/datasets/README.md](docs/datasets/README.md)** — DINO-WM / RT-1 / CSGO formats, downloads, splits
- **[docs/datasets/README.md](docs/datasets/README.md)** — DINO-WM / RT-1 / CSGO / BlockWorld formats, downloads, splits
- **[docs/applications/planning.md](docs/applications/planning.md)** — MPC + CEM model-predictive control
- **[docs/applications/long_rollout.md](docs/applications/long_rollout.md)** — long-horizon autoregressive rollout
- **[docs/applications/video_to_3d.md](docs/applications/video_to_3d.md)** — Depth Anything 3 point cloud pipeline
4 changes: 2 additions & 2 deletions docs/README.md
@@ -11,7 +11,7 @@ docs/
├── training.md (training workflow + design choices + ablation tables)
├── evaluation.md (eval workflow + main result tables + sampling)
├── datasets/
│   └── README.md (DINO-WM, RT-1, CSGO formats and configs)
│   └── README.md (DINO-WM, RT-1, CSGO, BlockWorld formats and configs)
└── applications/
    ├── planning.md (MPC + CEM model-predictive control)
    ├── long_rollout.md (long-horizon autoregressive rollout)
@@ -23,7 +23,7 @@ docs/
- **[Configuration system](config_system.md)** — Hydra layout, composition, environment variables, common overrides, debugging.
- **[Training](training.md)** — workflow + the four design axes (prediction target, action injection, model scale, EMA) with ablation tables and pretrained checkpoints.
- **[Evaluation](evaluation.md)** — `experiment=evaluate_only`, scheduling modes, metric definitions, headline numbers on each domain.
- **[Datasets](datasets/README.md)** — DINO-WM (5 envs), RT-1 fractal, CSGO. Download / split / format / config.
- **[Datasets](datasets/README.md)** — DINO-WM (5 envs), RT-1 fractal, CSGO, BlockWorld. Download / split / format / config.
- **[Planning](applications/planning.md)** — MPC + CEM over the diffusion world model. point_maze and PushT recipes.
- **[Long rollout](applications/long_rollout.md)** — 50-frame autoregressive rollout with sliding context window. CSGO demo.
- **[Video → 3D point cloud](applications/video_to_3d.md)** — DA3 multi-view depth + viser viewer.
6 changes: 6 additions & 0 deletions docs/config_system.md
@@ -164,6 +164,12 @@ Each dataset family has a `base.yaml` that fixes the schema, plus per-dataset ov
# action_dim=51 (keys + mouse), normalize_action=False
```

**Memory** (`src/configs/dataset/memory/`):
```yaml
# blockworld.yaml — FlowM 3D BlockWorld memory benchmark,
# train_slice_mode=random, action_dim=5 one-hot actions
```

**RT-1** (`src/configs/dataset/rt1/`):
```yaml
# rt1.yaml — LeRobot HF dataset (IPEC-COMMUNITY/fractal20220817_data_lerobot),
64 changes: 63 additions & 1 deletion docs/datasets/README.md
@@ -1,6 +1,6 @@
# Datasets

The repo ships configs for three dataset families: **DINO-WM** (5 simulated environments), **RT-1 fractal** (real-robot LeRobot), and **CSGO** (Counter-Strike deathmatch). All datasets feed through a common `WorldModelDataset` interface — see `src/wm_datasets/`.
The repo ships configs for four dataset families: **DINO-WM** (5 simulated environments), **RT-1 fractal** (real-robot LeRobot), **CSGO** (Counter-Strike deathmatch), and **BlockWorld** (FlowM 3D memory benchmark). All datasets feed through a common `WorldModelDataset` interface — see `src/wm_datasets/`.

<div align="center">

@@ -16,6 +16,8 @@ export CSGO_DATA_DIR=/path/to/csgo # CSGO root (HDF5 files)
export RT1_DATA_ROOT=/path/to/rt1_fractal # RT-1 (LeRobot mirror)
```

For BlockWorld, set `DATASET_DIR` to the parent directory containing `blockworld/`.

Or use the gitignored `src/configs/local/paths.yaml` (template at `paths.yaml.example`). See [config_system.md](../config_system.md#path-configuration).

## At a glance
@@ -31,6 +33,7 @@ Or use the gitignored `src/configs/local/paths.yaml` (template at `paths.yaml.ex
| DINO-WM Granular | ~500 | ~100–200 | 256² | 2 | exhaustive | deformable |
| RT-1 (fractal) | 87k | ~40–60 | 256² | 7 | random | LeRobot v2.0, frame_interval=1 |
| CSGO | 5500 | 1000 | 320×512 | 51 | random | fixed val start indices, frame_interval=1 |
| BlockWorld | config-dependent | 70–140 | 128² | 5 | random | FlowM 3D memory benchmark |

</div>

@@ -201,6 +204,65 @@ The shipped checkpoints are NanoWM-L/2 trained for 50k or 100k steps. See [train

---

## BlockWorld

FlowM's 3D Dynamic BlockWorld memory benchmark. Source: [hlillemark/flowm](https://github.com/hlillemark/flowm). The dataset uses RGB videos plus per-frame discrete actions; nano-world-model exposes those actions as 5-D one-hot vectors.

### Download

Use the FlowM dataset downloader and place the extracted dataset under `${DATASET_DIR}/blockworld`:

```bash
cd /path/to/flowm
bash ./download_datasets.sh --dataset blockworld --configs dynamic --splits train,validation
```

### Data format

Expected layout:

```text
${DATASET_DIR}/blockworld/
├── sunday_v2_training/
│   └── 0/
│       ├── 0000_rgb.mp4
│       ├── 0000_depth.mp4
│       └── 0000_actions.pt
└── sunday_v2_validation/
```

The `*_actions.pt` sidecar must contain either an `actions` tensor of integer action ids or an already-vectorized `[T, 5]` tensor. Depth videos are left untouched; the current integration is RGB-only to match the existing NanoWM video/action interface.
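The id-to-one-hot convention described above can be sketched in plain Python. `ids_to_one_hot` is a hypothetical helper for illustration only — the repo's actual loader operates on the torch tensors stored in `*_actions.pt`:

```python
def ids_to_one_hot(ids, action_dim=5):
    """Expand integer action ids into a [T, action_dim] list of one-hot rows."""
    rows = []
    for i in ids:
        if not 0 <= i < action_dim:
            raise ValueError(f"action id {i} outside [0, {action_dim})")
        row = [0.0] * action_dim
        row[i] = 1.0  # mark the chosen discrete action
        rows.append(row)
    return rows
```

An already-vectorized `[T, 5]` tensor would be passed through unchanged.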

### Config

`src/configs/dataset/memory/blockworld.yaml`:

```yaml
name: "blockworld"
frame_interval: 1
loader:
  data_path_train: "${dataset_dir}/blockworld/sunday_v2_training"
  data_path_val: "${dataset_dir}/blockworld/sunday_v2_validation"
  normalize_action: False
  train_slice_mode: "random"
  val_slice_mode: "exhaustive"
spec:
  action_dim: 5
```

### Train command

```bash
python src/main.py dataset=memory/blockworld model=nanowm_b2
```

### Memory

- Actions are small and cached on first trajectory access.
- Frames are decoded on demand from MP4 and are never cached by the data source.
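This policy — cache the small per-trajectory action tensors, re-decode frames every time — can be sketched as follows (hypothetical names; the actual `BlockWorldDataSource` internals may differ):

```python
class TrajectoryCacheSketch:
    """Cache per-trajectory actions on first access; never cache frames."""

    def __init__(self, load_actions, decode_frames):
        self._load_actions = load_actions    # cheap, small result: worth caching
        self._decode_frames = decode_frames  # expensive, large result: not cached
        self._action_cache = {}

    def actions(self, traj_id):
        # Loaded from the sidecar once, then served from memory.
        if traj_id not in self._action_cache:
            self._action_cache[traj_id] = self._load_actions(traj_id)
        return self._action_cache[traj_id]

    def frames(self, traj_id, start, end):
        # Decoded on demand from MP4; intentionally never stored.
        return self._decode_frames(traj_id, start, end)
```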

---

## Adding a new dataset

1. **DataSource** in `src/wm_datasets/data_source/`:
20 changes: 20 additions & 0 deletions src/configs/dataset/memory/base.yaml
@@ -0,0 +1,20 @@
# @package dataset
loader:
  n_rollout: null
  data_path_train: null
  data_path_val: null
  split_ratio: 0.9
  validation_size: 32
  normalize_state: False
  normalize_action: False
  train_slice_mode: "random"
  val_slice_mode: "exhaustive"
  stride: 1
  random_seed: 42
  validation_fixed_subset_path: null
  validation_fixed_subset_size: null
  validation_fixed_subset_seed: 42
  file_list: null
  action_dim: 5
  video_suffix: "_rgb.mp4"
  action_suffix: "_actions.pt"
17 changes: 17 additions & 0 deletions src/configs/dataset/memory/blockworld.yaml
@@ -0,0 +1,17 @@
defaults:
  - memory/base

name: "blockworld"
frame_interval: 1

loader:
  data_path_train: "${dataset_dir}/blockworld/sunday_v2_training"
  data_path_val: "${dataset_dir}/blockworld/sunday_v2_validation"
  normalize_action: False
  resize_mode: "stretch"
  train_slice_mode: "random"
  val_slice_mode: "exhaustive"
  action_dim: 5

spec:
  action_dim: 5
2 changes: 2 additions & 0 deletions src/wm_datasets/README.md
@@ -77,6 +77,7 @@ train_dataset, val_dataset = create_train_val_datasets(
| **pusht** | 18,685 | 5 | 2 | ~109 |
| **rope** | TBD | TBD | TBD | TBD |
| **granular** | TBD | TBD | TBD | TBD |
| **blockworld** | config-dependent | 0 | 5 | 70–140 |

</div>

@@ -130,6 +131,7 @@ def create_world_model_dataset(
- `use_relative_actions` (pusht): Use relative vs absolute actions
- `action_scale` (pusht): Action scaling factor
- `object_name` (deformable): "rope" or "granular"
- `action_dim` (blockworld): Number of discrete actions for one-hot encoding

### Sampling Modes

2 changes: 2 additions & 0 deletions src/wm_datasets/__init__.py
@@ -9,6 +9,7 @@

from .data_source import (
    DataSource,
    BlockWorldDataSource,
    DinoWorldModelDataSource,
    LeRobotDataSource,
    TrajectoryData,
@@ -25,6 +26,7 @@
__all__ = [
    # DataSource layer
    "DataSource",
    "BlockWorldDataSource",
    "DinoWorldModelDataSource",
    "LeRobotDataSource",
    "TrajectoryData",
2 changes: 2 additions & 0 deletions src/wm_datasets/data_source/__init__.py
@@ -5,6 +5,7 @@
from .base import DataSource, TrajectoryData
from .dino_wm import DinoWorldModelDataSource, PushTDataSource, DeformableEnvDataSource
from .lerobot import LeRobotDataSource
from .memory import BlockWorldDataSource
from .factory import create_data_source

__all__ = [
@@ -14,5 +15,6 @@
    "PushTDataSource",
    "DeformableEnvDataSource",
    "LeRobotDataSource",
    "BlockWorldDataSource",
    "create_data_source",
]
14 changes: 13 additions & 1 deletion src/wm_datasets/data_source/factory.py
@@ -7,6 +7,7 @@
from .dino_wm import DinoWorldModelDataSource, PushTDataSource, DeformableEnvDataSource
from .lerobot import LeRobotDataSource
from .game import CSGODataSource
from .memory import BlockWorldDataSource
from .base import DataSource


@@ -72,7 +73,18 @@ def create_data_source(
        **csgo_kwargs
    )

    if dataset_name == "blockworld":
        blockworld_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['action_dim', 'file_list', 'video_suffix', 'action_suffix']
        }
        return BlockWorldDataSource(
            data_path=data_path,
            n_rollout=n_rollout,
            **blockworld_kwargs,
        )

    raise ValueError(
        f"Unknown dataset: {dataset_name}. "
        f"Supported: point_maze, wall, pusht, rope, granular, rt1, csgo"
        f"Supported: point_maze, wall, pusht, rope, granular, rt1, csgo, blockworld"
    )
7 changes: 7 additions & 0 deletions src/wm_datasets/data_source/memory/__init__.py
@@ -0,0 +1,7 @@
"""
Memory benchmark dataset sources.
"""

from .blockworld_data_source import BlockWorldDataSource

__all__ = ["BlockWorldDataSource"]