
[BUG] Transforming a ParallelEnv causes all sub-envs to be reset when one of them is done in a multi-agent setting #2958

Closed

@thomasbbrunner

Description

Describe the bug

When a transform is applied to a ParallelEnv in a multi-agent setting, the ParallelEnv's reset behavior changes.

Specifically, ALL sub-environments in the ParallelEnv are reset as soon as one of them is done (and its agent-level done fields are True).

To Reproduce

The reproduction is straightforward; see the code snippet below. It involves:

  1. Creating an environment with a global done key and agent-level done keys
  2. Setting one of the agent-level done keys to True on every step
  3. Setting the global done field of one of the sub-environments to True
  4. Observing that ALL sub-environments are reset, even though only one of them is done

Note that the plain ParallelEnv with two sub-environments correctly resets them after 4 and 6 steps, respectively (as expected).

However, the transformed ParallelEnv resets both of its sub-environments after 4 steps (not expected)!

import torch

from tensordict import TensorDict
from torchrl.data import Binary, Composite
from torchrl.envs import EnvBase, InitTracker, ParallelEnv, TransformedEnv


class MockEnv(EnvBase):
    def __init__(self, num_steps: int) -> None:
        super().__init__(device="cpu")
        self._num_steps = num_steps
        self._counter = 0
        self.done_spec = Composite(
            {
                "done": Binary(1, dtype=torch.bool),
                ("agent_1", "done"): Binary(1, dtype=torch.bool),
                ("agent_2", "done"): Binary(1, dtype=torch.bool),
            }
        )

    def _reset(self, tensordict: TensorDict) -> TensorDict:
        print(f"Reset after {self._counter} steps!")
        self._counter = 0
        return TensorDict(
            {
                "done": torch.tensor([False], dtype=torch.bool),
                ("agent_1", "done"): torch.tensor([False], dtype=torch.bool),
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            },
            batch_size=[],
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        self._counter += 1
        done = torch.tensor([self._counter >= self._num_steps], dtype=torch.bool)
        return TensorDict(
            {
                "done": done,
                # NOTE: one of the agent-level done fields must be True for the bug to trigger.
                ("agent_1", "done"): torch.tensor([True], dtype=torch.bool),
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            },
            batch_size=[],
        )

    def _set_seed(self, seed=None):
        pass


if __name__ == "__main__":

    def make_env(num_steps):
        return MockEnv(num_steps)

    def manual_rollout(env: EnvBase, num_steps: int):
        steps = []
        td = env.reset()
        for _ in range(num_steps):
            td, next_td = env.step_and_maybe_reset(td)
            steps.append(td)
            td = next_td
        return TensorDict.stack(steps)

    # NOTE: we expect env[0] to reset after 4 steps and env[1] after 6 steps.
    parallel_env = ParallelEnv(2, create_env_fn=make_env, create_env_kwargs=[{"num_steps": i} for i in [4, 6]])
    transformed_env = TransformedEnv(
        env=ParallelEnv(2, create_env_fn=make_env, create_env_kwargs=[{"num_steps": i} for i in [4, 6]]),
        transform=InitTracker(),
    )

    print("ParallelEnv")
    parallel_td = manual_rollout(parallel_env, 6)
    # Will print:
    # ParallelEnv
    # Reset after 0 steps! (env[0])
    # Reset after 0 steps! (env[1])
    # Reset after 4 steps! (env[0])
    # Reset after 6 steps! (env[1])

    print("TransformedEnv(ParallelEnv)")
    transformed_td = manual_rollout(transformed_env, 6)
    # Will print:
    # TransformedEnv(ParallelEnv)
    # Reset after 0 steps! (env[0])
    # Reset after 0 steps! (env[1])
    # Reset after 4 steps! (env[0])
    # Reset after 4 steps! (env[1])     <---- BUG: why was this environment reset???

    # We expect each sub-env to have reached a done state exactly once:
    # env[0] after 4 steps and env[1] after 6 steps.
    assert parallel_td["next", "done"].sum().item() == 2
    # At this point env[0] has been reset and has executed 2 further steps,
    # while env[1] has just been reset (0 steps). (The per-env `_counter`
    # lives in the worker processes, so it cannot be inspected directly here.)

    # We expect the same for the transformed env, but this assertion fails:
    # env[1] was reset together with env[0] after 4 steps, so its counter was
    # restarted and it never reaches its own done state within the 6 steps.
    # BUG: env[1]'s done flag was never set, yet the environment was reset!
    assert transformed_td["next", "done"].sum().item() == 2

Reason and Possible fixes

ParallelEnv overrides step_and_maybe_reset with a fairly complex implementation that, among other things, handles per-worker (partial) resets.

However, when the ParallelEnv is wrapped in a TransformedEnv, this override is never called. Instead, the generic EnvBase.step_and_maybe_reset is used, which presumably does not handle partial resets in the same way.
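
To confirm which implementation actually gets dispatched, one can compare the class attributes directly. The following is a minimal diagnostic sketch, assuming the class layout described above (ParallelEnv overriding the method, TransformedEnv inheriting it from EnvBase); the expected output follows from that assumption rather than from a verified run.

from torchrl.envs import EnvBase, ParallelEnv, TransformedEnv

print(
    "ParallelEnv overrides step_and_maybe_reset:",
    ParallelEnv.step_and_maybe_reset is not EnvBase.step_and_maybe_reset,
)
print(
    "TransformedEnv overrides step_and_maybe_reset:",
    TransformedEnv.step_and_maybe_reset is not EnvBase.step_and_maybe_reset,
)
# If the analysis above holds, the first check prints True (ParallelEnv ships its
# own partial-reset-aware implementation) and the second prints False (the
# transformed env falls back to the generic EnvBase path).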

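As a possible workaround (a sketch under the assumptions above, not a fix of the underlying dispatch issue), the transform can be applied inside each worker instead of around the ParallelEnv, so that ParallelEnv.step_and_maybe_reset remains in charge of partial resets. MockEnv refers to the class defined in the reproduction snippet.

from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

def make_transformed_env(num_steps):
    # Wrap each sub-env individually; the outer env then stays a plain
    # ParallelEnv whose own step_and_maybe_reset handles per-worker resets.
    return TransformedEnv(MockEnv(num_steps), InitTracker())

per_worker_transformed_env = ParallelEnv(
    2,
    create_env_fn=make_transformed_env,
    create_env_kwargs=[{"num_steps": i} for i in [4, 6]],
)
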
Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
