Skip to content

Commit efa534f

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 955c0af + 26b5c9c commit efa534f

File tree

21 files changed

+172
-15
lines changed

21 files changed

+172
-15
lines changed

test/test_cost.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from tensordict.nn.utils import Buffer
4141
from tensordict.utils import unravel_key
4242
from torch import autograd, nn
43+
from torchrl._utils import _standardize
4344
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
4445
from torchrl.data.postprocs.postprocs import MultiStep
4546
from torchrl.envs.model_based.dreamer import DreamerEnv
@@ -16044,6 +16045,18 @@ def _composite_log_prob(self):
1604416045
yield
1604516046
setter.unset()
1604616047

16048+
def test_standardization(self):
16049+
t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6)
16050+
std_t0 = _standardize(t, exclude_dims=(1, 3))
16051+
std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std((0, 2), keepdim=True).clamp(
16052+
1 - 6
16053+
)
16054+
torch.testing.assert_close(std_t0, std_t1)
16055+
std_t = _standardize(t, (), -1, 2)
16056+
torch.testing.assert_close(std_t, (t + 1) / 2)
16057+
std_t = _standardize(t, ())
16058+
torch.testing.assert_close(std_t, (t - t.mean()) / t.std())
16059+
1604716060
@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
1604816061
@pytest.mark.parametrize("T", [1, 10])
1604916062
@pytest.mark.parametrize("device", get_default_devices())

torchrl/_utils.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@
2424
from distutils.util import strtobool
2525
from functools import wraps
2626
from importlib import import_module
27-
from typing import Any, Callable, cast, Dict, TypeVar, Union
27+
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union
2828

2929
import numpy as np
3030
import torch
3131
from packaging.version import parse
3232
from tensordict import unravel_key
3333

3434
from tensordict.utils import NestedKey
35-
from torch import multiprocessing as mp
35+
from torch import multiprocessing as mp, Tensor
3636

3737
try:
3838
from torch.compiler import is_compiling
@@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None:
872872
self._mode = type
873873

874874

875+
def _standardize(
876+
input: Tensor,
877+
exclude_dims: Tuple[int] = (),
878+
mean: Tensor | None = None,
879+
std: Tensor | None = None,
880+
eps: float | None = None,
881+
):
882+
"""Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
883+
884+
Useful when processing multi-agent data to keep the agent dimensions independent.
885+
886+
Args:
887+
input (Tensor): the input tensor to be standardized.
888+
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
889+
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
890+
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
891+
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
892+
893+
"""
894+
if eps is None:
895+
if input.dtype.is_floating_point:
896+
eps = torch.finfo(torch.float).resolution
897+
else:
898+
eps = 1e-6
899+
900+
len_exclude_dims = len(exclude_dims)
901+
if not len_exclude_dims:
902+
if mean is None:
903+
mean = input.mean()
904+
else:
905+
# Assume dtypes are compatible
906+
mean = torch.as_tensor(mean, device=input.device)
907+
if std is None:
908+
std = input.std()
909+
else:
910+
# Assume dtypes are compatible
911+
std = torch.as_tensor(std, device=input.device)
912+
return (input - mean) / std.clamp_min(eps)
913+
914+
input_shape = input.shape
915+
exclude_dims = [
916+
d if d >= 0 else d + len(input_shape) for d in exclude_dims
917+
] # Make negative dims positive
918+
919+
if len(set(exclude_dims)) != len_exclude_dims:
920+
raise ValueError("Exclude dims has repeating elements")
921+
if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
922+
raise ValueError(
923+
f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
924+
)
925+
if len_exclude_dims == len(input_shape):
926+
warnings.warn(
927+
"_standardize called but all dims were excluded from the statistics, returning unprocessed input"
928+
)
929+
return input
930+
931+
included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
932+
if mean is None:
933+
mean = torch.mean(input, keepdim=True, dim=included_dims)
934+
if std is None:
935+
std = torch.std(input, keepdim=True, dim=included_dims)
936+
return (input - mean) / std.clamp_min(eps)
937+
938+
875939
@wraps(torch.compile)
876940
def compile_with_warmup(*args, warmup: int = 1, **kwargs):
877941
"""Compile a model with warm-up.

torchrl/envs/batched_envs.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
12171217
__doc__ += BatchedEnvBase.__doc__
12181218
__doc__ += """
12191219
1220+
.. note:: ParallelEnv will timeout after one of the worker is idle for a determinate amount of time.
1221+
This can be controlled via the BATCHED_PIPE_TIMEOUT environment variable, which in turn modifies
1222+
the torchrl._utils.BATCHED_PIPE_TIMEOUT integer. The default timeout value is 10000 seconds.
1223+
12201224
.. warning::
12211225
TorchRL's ParallelEnv is quite stringent when it comes to env specs, since
12221226
these are used to build shared memory buffers for inter-process communication.
@@ -1353,7 +1357,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
13531357
"""
13541358

13551359
def _start_workers(self) -> None:
1360+
import torchrl
1361+
13561362
self._timeout = 10.0
1363+
self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT
13571364

13581365
from torchrl.envs.env_creator import EnvCreator
13591366

@@ -1606,7 +1613,7 @@ def step_and_maybe_reset(
16061613

16071614
for i in workers_range:
16081615
event = self._events[i]
1609-
event.wait(self._timeout)
1616+
event.wait(self.BATCHED_PIPE_TIMEOUT)
16101617
event.clear()
16111618

16121619
if self._non_tensor_keys:
@@ -1796,7 +1803,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
17961803

17971804
for i in workers_range:
17981805
event = self._events[i]
1799-
event.wait(self._timeout)
1806+
event.wait(self.BATCHED_PIPE_TIMEOUT)
18001807
event.clear()
18011808

18021809
if self._non_tensor_keys:
@@ -1965,7 +1972,7 @@ def tentative_update(val, other):
19651972

19661973
for i, _ in outs:
19671974
event = self._events[i]
1968-
event.wait(self._timeout)
1975+
event.wait(self.BATCHED_PIPE_TIMEOUT)
19691976
event.clear()
19701977

19711978
workers_nontensor = []
@@ -2023,7 +2030,7 @@ def _shutdown_workers(self) -> None:
20232030
for channel in self.parent_channels:
20242031
channel.close()
20252032
for proc in self._workers:
2026-
proc.join(timeout=1.0)
2033+
proc.join(timeout=self._timeout)
20272034
finally:
20282035
for proc in self._workers:
20292036
if proc.is_alive():

torchrl/modules/tensordict_module/probabilistic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
import warnings
78
from typing import Dict, List, Optional, Type, Union

torchrl/objectives/a2c.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def __post_init__(self):
257257
self.sample_log_prob = "action_log_prob"
258258

259259
default_keys = _AcceptedKeys
260+
tensor_keys: _AcceptedKeys
260261
default_value_estimator: ValueEstimators = ValueEstimators.GAE
261262

262263
actor_network: TensorDictModule

torchrl/objectives/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class _AcceptedKeys:
128128

129129
pass
130130

131+
tensor_keys: _AcceptedKeys
131132
_vmap_randomness = None
132133
default_value_estimator: ValueEstimators = None
133134

torchrl/objectives/cql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ class _AcceptedKeys:
260260
done: NestedKey = "done"
261261
terminated: NestedKey = "terminated"
262262

263+
tensor_keys: _AcceptedKeys
263264
default_keys = _AcceptedKeys
264265
default_value_estimator = ValueEstimators.TD0
265266

@@ -1024,6 +1025,7 @@ class _AcceptedKeys:
10241025
terminated: NestedKey = "terminated"
10251026
pred_val: NestedKey = "pred_val"
10261027

1028+
tensor_keys: _AcceptedKeys
10271029
default_keys = _AcceptedKeys
10281030
default_value_estimator = ValueEstimators.TD0
10291031
out_keys = [

torchrl/objectives/crossq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class _AcceptedKeys:
242242
terminated: NestedKey = "terminated"
243243
log_prob: NestedKey = "_log_prob"
244244

245+
tensor_keys: _AcceptedKeys
245246
default_keys = _AcceptedKeys
246247
default_value_estimator = ValueEstimators.TD0
247248

torchrl/objectives/ddpg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ class _AcceptedKeys:
173173
done: NestedKey = "done"
174174
terminated: NestedKey = "terminated"
175175

176+
tensor_keys: _AcceptedKeys
176177
default_keys = _AcceptedKeys
177178
default_value_estimator: ValueEstimators = ValueEstimators.TD0
178179
out_keys = [

torchrl/objectives/decision_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class _AcceptedKeys:
7070
# the "action" output from the model
7171
action_pred: NestedKey = "action"
7272

73+
tensor_keys: _AcceptedKeys
7374
default_keys = _AcceptedKeys
7475

7576
actor_network: TensorDictModule
@@ -280,6 +281,7 @@ class _AcceptedKeys:
280281
# the "action" output from the model
281282
action_pred: NestedKey = "action"
282283

284+
tensor_keys: _AcceptedKeys
283285
default_keys = _AcceptedKeys
284286

285287
actor_network: TensorDictModule

torchrl/objectives/dqn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class _AcceptedKeys:
164164
done: NestedKey = "done"
165165
terminated: NestedKey = "terminated"
166166

167+
tensor_keys: _AcceptedKeys
167168
default_keys = _AcceptedKeys
168169
default_value_estimator = ValueEstimators.TD0
169170
out_keys = ["loss"]
@@ -435,6 +436,7 @@ class _AcceptedKeys:
435436
terminated: NestedKey = "terminated"
436437
steps_to_next_obs: NestedKey = "steps_to_next_obs"
437438

439+
tensor_keys: _AcceptedKeys
438440
default_keys = _AcceptedKeys
439441
default_value_estimator = ValueEstimators.TD0
440442

torchrl/objectives/dreamer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,13 @@ class _AcceptedKeys:
8989
pixels: NestedKey = "pixels"
9090
reco_pixels: NestedKey = "reco_pixels"
9191

92+
tensor_keys: _AcceptedKeys
9293
default_keys = _AcceptedKeys
9394

95+
decoder: TensorDictModule
96+
reward_model: TensorDictModule
97+
world_mdel: TensorDictModule
98+
9499
def __init__(
95100
self,
96101
world_model: TensorDictModule,
@@ -238,9 +243,13 @@ class _AcceptedKeys:
238243
done: NestedKey = "done"
239244
terminated: NestedKey = "terminated"
240245

246+
tensor_keys: _AcceptedKeys
241247
default_keys = _AcceptedKeys
242248
default_value_estimator = ValueEstimators.TDLambda
243249

250+
value_model: TensorDictModule
251+
actor_model: TensorDictModule
252+
244253
def __init__(
245254
self,
246255
actor_model: TensorDictModule,
@@ -392,8 +401,11 @@ class _AcceptedKeys:
392401

393402
value: NestedKey = "state_value"
394403

404+
tensor_keys: _AcceptedKeys
395405
default_keys = _AcceptedKeys
396406

407+
value_model: TensorDictModule
408+
397409
def __init__(
398410
self,
399411
value_model: TensorDictModule,

torchrl/objectives/gail.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class _AcceptedKeys:
5959
collector_observation: NestedKey = "collector_observation"
6060
discriminator_pred: NestedKey = "d_logits"
6161

62+
tensor_keys: _AcceptedKeys
6263
default_keys = _AcceptedKeys
6364

6465
discriminator_network: TensorDictModule

torchrl/objectives/iql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ class _AcceptedKeys:
233233
done: NestedKey = "done"
234234
terminated: NestedKey = "terminated"
235235

236+
tensor_keys: _AcceptedKeys
236237
default_keys = _AcceptedKeys
237238
default_value_estimator = ValueEstimators.TD0
238239
out_keys = [
@@ -709,6 +710,7 @@ class _AcceptedKeys:
709710
done: NestedKey = "done"
710711
terminated: NestedKey = "terminated"
711712

713+
tensor_keys: _AcceptedKeys
712714
default_keys = _AcceptedKeys
713715
default_value_estimator = ValueEstimators.TD0
714716
out_keys = [

torchrl/objectives/multiagent/qmixer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class _AcceptedKeys:
179179
done: NestedKey = "done"
180180
terminated: NestedKey = "terminated"
181181

182+
tensor_keys: _AcceptedKeys
182183
default_keys = _AcceptedKeys
183184
default_value_estimator = ValueEstimators.TD0
184185
out_keys = ["loss"]

0 commit comments

Comments
 (0)