Description
Describe the bug
When a transform is applied to a `ParallelEnv`, the reset behavior of the `ParallelEnv` changes in a multi-agent setting.
This causes all sub-environments in the ParallelEnv to reset as soon as one of them is done (and the agent-level done fields are True).
To Reproduce
The reproduction is straightforward, see the code snippet below. This involves:
- Creating an environment with global and agent-level done keys
- Setting one of the agent-level done keys to True
- Setting the done field of one of the sub-environments to True
- ALL sub-environments will be reset, even though only one of them is done.
Note that the plain `ParallelEnv` with two sub-environments correctly resets them after 4 and 6 steps respectively (as expected).
However, the transformed `ParallelEnv` resets all of its sub-environments after 4 steps (not expected)!
from torchrl.envs import TransformedEnv, ParallelEnv, EnvBase, InitTracker
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.tensor_specs import Binary, Composite
class MockEnv(EnvBase):
    """Minimal multi-agent mock environment used to reproduce the reset bug.

    The environment exposes one global ``"done"`` entry plus one ``"done"``
    entry per agent (``agent_1`` and ``agent_2``). The global flag becomes
    ``True`` once ``num_steps`` steps have been taken since the last reset,
    while ``agent_1``'s flag is ``True`` on every step.

    Args:
        num_steps: number of steps after which the global ``"done"`` flag
            is set.
    """

    def __init__(self, num_steps: int) -> None:
        super().__init__(device="cpu")
        self._num_steps = num_steps
        # Steps taken since the last reset; printed on reset to show when
        # each sub-environment was actually reset.
        self._counter = 0
        self.done_spec = Composite(
            {
                "done": Binary(1, dtype=torch.bool),
                ("agent_1", "done"): Binary(1, dtype=torch.bool),
                ("agent_2", "done"): Binary(1, dtype=torch.bool),
            }
        )

    def _reset(self, tensordict: TensorDict) -> TensorDict:
        """Log the number of steps taken, zero the counter and clear all done flags."""
        print(f"Reset after {self._counter} steps!")
        self._counter = 0
        return TensorDict(
            {
                "done": torch.tensor([False], dtype=torch.bool),
                ("agent_1", "done"): torch.tensor([False], dtype=torch.bool),
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            },
            batch_size=[],
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Advance the step counter and report the global and per-agent done flags."""
        self._counter += 1
        done = torch.tensor([self._counter >= self._num_steps], dtype=torch.bool)
        return TensorDict(
            {
                "done": done,
                # NOTE: one of the agent-level done fields must be True for the bug to trigger.
                ("agent_1", "done"): torch.tensor([True], dtype=torch.bool),
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            },
            batch_size=[],
        )

    def _set_seed(self, seed: "int | None" = None) -> None:
        """Accept and ignore the seed: this mock environment is deterministic.

        The original definition took no parameters at all (not even ``self``),
        so calling it as a bound method raised ``TypeError``.
        """
if __name__ == "__main__":
    # Reproduction script: run the same 6-step rollout on a plain ParallelEnv
    # and on a TransformedEnv wrapping an identical ParallelEnv, then compare
    # when each sub-environment gets reset.

    def make_env(num_steps):
        """Build one MockEnv whose global done flag fires after ``num_steps`` steps."""
        return MockEnv(num_steps)

    def make_parallel_env():
        """Build a 2-worker ParallelEnv: env[0] is done after 4 steps, env[1] after 6."""
        return ParallelEnv(
            2,
            create_env_fn=make_env,
            create_env_kwargs=[{"num_steps": i} for i in [4, 6]],
        )

    def manual_rollout(env: EnvBase, num_steps: int):
        """Step ``env`` ``num_steps`` times via ``step_and_maybe_reset`` and stack the results."""
        steps = []
        td = env.reset()
        for _ in range(num_steps):
            td, next_td = env.step_and_maybe_reset(td)
            steps.append(td)
            td = next_td
        return TensorDict.stack(steps)

    # NOTE: we expect the env[0] to reset after 4 steps, env[1] to reset after 6 steps.
    parallel_env = make_parallel_env()
    transformed_env = TransformedEnv(
        env=make_parallel_env(),
        transform=InitTracker(),
    )

    try:
        print("ParallelEnv")
        parallel_td = manual_rollout(parallel_env, 6)
        # Will print:
        # ParallelEnv
        # Reset after 0 steps! (env[0])
        # Reset after 0 steps! (env[1])
        # Reset after 4 steps! (env[0])
        # Reset after 6 steps! (env[1])

        print("TransformedEnv(ParallelEnv)")
        transformed_td = manual_rollout(transformed_env, 6)
        # Will print
        # TransformedEnv(ParallelEnv)
        # Reset after 0 steps! (env[0])
        # Reset after 0 steps! (env[1])
        # Reset after 4 steps! (env[0])
        # Reset after 4 steps! (env[1]) <---- BUG: why was this environment reset???

        # We expect each env to have reached a done state once.
        assert parallel_td["next", "done"].sum().item() == 2
        # We expect env[0] to have been reset and executed 2 steps.
        # We expect env[1] to have just been reset (0 steps).
        # NOTE(review): `_counter()` presumably relies on the batched env
        # forwarding the attribute lookup to its workers and returning one
        # value per sub-environment -- confirm against the torchrl version used.
        assert parallel_env._counter() == [2, 0]

        # The transformed env should behave exactly like the plain one:
        # each env reaches a done state once ...
        # BUG: env[1] is reset after 4 steps although its done flag is not
        # set, so only env[0] reaches a done state and these assertions fail.
        assert transformed_td["next", "done"].sum().item() == 2
        # ... env[0] has been reset and executed 2 steps, and env[1] has just
        # been reset (0 steps).
        assert transformed_env._counter() == [2, 0]
    finally:
        # Shut the worker processes down even when an assertion above fails.
        parallel_env.close()
        transformed_env.close()
Reason and Possible fixes
The `ParallelEnv` has a fairly complex `step_and_maybe_reset` method. However, when a `ParallelEnv` is wrapped in a `TransformedEnv`, that `step_and_maybe_reset` is not called.
Instead, `EnvBase.step_and_maybe_reset` is called, which presumably does not work the same way.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)