Skip to content

[BugFix] Fix behavior of partial, nested dones in PEnv and TEnv #2959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import expand_right, NestedKey

from torchrl._utils import logger as torchrl_logger
from torchrl.data import (
Binary,
Bounded,
Expand Down Expand Up @@ -2533,3 +2533,58 @@ def __next__(self):
else:
tokens = tensors
return {"tokens": tokens, "attention_mask": tokens != 0}


class MockNestedResetEnv(EnvBase):
    """Mock env for checking nested done handling: the root done prevails over the nested ones."""

    def __init__(self, num_steps: int, done_at_root: bool) -> None:
        super().__init__(device="cpu")
        # Episode length after which the env reports done, and the running step counter.
        self._counter = 0
        self._num_steps = num_steps
        self.done_at_root = done_at_root
        # Two nested per-agent done entries; optionally a root-level "done" as well.
        self.done_spec = Composite(
            {
                ("agent_1", "done"): Binary(1, dtype=torch.bool),
                ("agent_2", "done"): Binary(1, dtype=torch.bool),
            }
        )
        if done_at_root:
            self.full_done_spec["done"] = Binary(1, dtype=torch.bool)

    def _reset(self, tensordict: TensorDict) -> TensorDict:
        # Log how far the episode got before this reset (useful for the partial-reset tests).
        torchrl_logger.info(f"Reset after {self._counter} steps!")
        if tensordict is not None:
            torchrl_logger.info(f"tensordict at reset {tensordict.to_dict()}")
        self._counter = 0
        entries = {
            ("agent_1", "done"): torch.tensor([False], dtype=torch.bool),
            ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
        }
        if self.done_at_root:
            entries["done"] = torch.tensor([False], dtype=torch.bool)
        return TensorDict(entries)

    def _step(self, tensordict: TensorDict) -> TensorDict:
        self._counter += 1
        # Done once the configured number of steps has been taken.
        done = torch.tensor([self._counter >= self._num_steps], dtype=torch.bool)
        if self.done_at_root:
            # Root done carries the signal; agent_1 is always done, agent_2 never.
            entries = {
                "done": done,
                ("agent_1", "done"): torch.tensor([True], dtype=torch.bool),
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            }
        else:
            # Only agent_1 carries the done signal.
            entries = {
                ("agent_1", "done"): done,
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            }
        return TensorDict(entries)

    def _set_seed(self):
        # Deterministic mock env: nothing to seed.
        pass
59 changes: 59 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from torchrl.envs.transforms.transforms import (
AutoResetEnv,
AutoResetTransform,
InitTracker,
Tokenizer,
Transform,
UnsqueezeTransform,
Expand Down Expand Up @@ -143,6 +144,7 @@
HistoryTransform,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MockNestedResetEnv,
MockSerialEnv,
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
Expand Down Expand Up @@ -184,6 +186,7 @@
HistoryTransform,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MockNestedResetEnv,
MockSerialEnv,
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
Expand Down Expand Up @@ -2925,6 +2928,62 @@ def test_nested_reset(self, nest_done, has_root_done, batch_size):
env.rollout(100)
env.rollout(100, break_when_any_done=False)

@pytest.mark.parametrize("done_at_root", [True, False])
def test_nested_partial_resets(self, maybe_fork_ParallelEnv, done_at_root):
    """Partial, nested resets must behave identically in a ParallelEnv and a TransformedEnv wrapping one."""

    def env_maker(num_steps):
        return MockNestedResetEnv(num_steps, done_at_root)

    def manual_rollout(env: EnvBase, num_steps: int):
        # Drive the env by hand with step_and_maybe_reset and stack the post-step frames.
        collected = []
        cur = env.reset()
        for _ in range(num_steps):
            stepped, cur = env.step_and_maybe_reset(cur)
            collected.append(stepped)
        return TensorDict.stack(collected)

    # NOTE: we expect the env[0] to reset after 4 steps, env[1] to reset after 6 steps.
    parallel_env = maybe_fork_ParallelEnv(
        2,
        create_env_fn=env_maker,
        create_env_kwargs=[{"num_steps": n} for n in (4, 6)],
    )
    transformed_env = TransformedEnv(
        env=maybe_fork_ParallelEnv(
            2,
            create_env_fn=env_maker,
            create_env_kwargs=[{"num_steps": n} for n in (4, 6)],
        ),
        transform=InitTracker(),
    )

    parallel_td = manual_rollout(parallel_env, 6)

    transformed_td = manual_rollout(transformed_env, 6)

    # We expect env[0] to have been reset and executed 2 steps.
    # We expect env[1] to have just been reset (0 steps).
    assert parallel_env._counter() == [2, 0]
    assert transformed_env._counter() == [2, 0]
    if done_at_root:
        assert parallel_env._simple_done
        assert transformed_env._simple_done
        # We expect each env to have reached a done state once.
        assert parallel_td["next", "done"].sum().item() == 2
        assert transformed_td["next", "done"].sum().item() == 2
        assert_allclose_td(transformed_td, parallel_td, intersection=True)
    else:
        assert not parallel_env._simple_done
        assert not transformed_env._simple_done

        # Without a root done, only the nested per-agent dones exist.
        assert ("next", "done") not in parallel_td
        assert ("next", "done") not in transformed_td
        assert parallel_td["next", "agent_1", "done"].sum().item() == 2
        assert transformed_td["next", "agent_1", "done"].sum().item() == 2
        assert_allclose_td(transformed_td, parallel_td, intersection=True)

    assert transformed_env._counter() == [2, 0]

class TestHeteroEnvs:
@pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
Expand Down
18 changes: 12 additions & 6 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,9 +995,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
for elt in list_of_kwargs:
elt.update(kwargs)
if tensordict is not None:
needs_resetting = _aggregate_end_of_traj(
tensordict, reset_keys=self.reset_keys
)
if "_reset" in tensordict.keys():
needs_resetting = tensordict["_reset"]
else:
needs_resetting = _aggregate_end_of_traj(
tensordict, reset_keys=self.reset_keys
)
if needs_resetting.ndim > 2:
needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
if needs_resetting.ndim > 1:
Expand Down Expand Up @@ -2114,9 +2117,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
elt.update(kwargs)

if tensordict is not None:
needs_resetting = _aggregate_end_of_traj(
tensordict, reset_keys=self.reset_keys
)
if "_reset" in tensordict.keys():
needs_resetting = tensordict["_reset"]
else:
needs_resetting = _aggregate_end_of_traj(
tensordict, reset_keys=self.reset_keys
)
if needs_resetting.ndim > 2:
needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
if needs_resetting.ndim > 1:
Expand Down
20 changes: 15 additions & 5 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2821,12 +2821,25 @@ def _reset_check_done(self, tensordict, tensordict_reset):
# we iterate over (reset_key, (done_key, truncated_key)) and check that all
# values where reset was true now have a done set to False.
# If no reset was present, all done and truncated must be False

# Once we checked a root, we don't check its leaves - so keep track of the roots. Fortunately, we sort the done
# keys in the done_keys_group from root to leaf
prefix_complete = set()
for reset_key, done_key_group in zip(self.reset_keys, self.done_keys_groups):
skip = False
if isinstance(reset_key, tuple):
for i in range(len(reset_key) - 1):
if reset_key[:i] in prefix_complete:
skip = True
break
if skip:
continue
reset_value = (
tensordict.get(reset_key, default=None)
if tensordict is not None
else None
)
prefix_complete.add(() if isinstance(reset_key, str) else reset_key[:-1])
if reset_value is not None:
for done_key in done_key_group:
done_val = tensordict_reset.get(done_key)
Expand Down Expand Up @@ -3580,11 +3593,8 @@ def step_and_maybe_reset(
@_cache_value
def _simple_done(self):
    # An env has a "simple" done layout when the root of its full done spec
    # exposes plain "done" and "terminated" entries (possibly alongside
    # others, e.g. "truncated"), as opposed to only nested per-agent dones.
    # NOTE(review): the rendered diff interleaved the removed exact-set
    # comparison with the new membership test; this is the post-diff version.
    key_set = set(self.full_done_spec.keys())
    _simple_done = "done" in key_set and "terminated" in key_set
    return _simple_done

def any_done(self, tensordict: TensorDictBase) -> bool:
Expand Down
Loading