
Commit 63a1457

Author: Vincent Moens
[Feature] use_vmap=False for SAC
ghstack-source-id: d66b53a
Pull Request resolved: #2392
1 parent 25e8bd2 commit 63a1457

File tree

4 files changed, +167 -10 lines changed

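The feature in one glance: `use_vmap` is a new keyword argument on the `SACLoss` constructor. A minimal sketch of the two modes (the `actor` and `qvalue` modules are assumed to be built as in the `SACLoss` docstring example; the snippet itself is illustrative and not part of the commit):

from torchrl.objectives import SACLoss

# `actor` and `qvalue` are assumed to be a ProbabilisticActor and a
# state-action ValueOperator, as in the SACLoss docstring example.
loss_vmap = SACLoss(actor_network=actor, qvalue_network=qvalue,
                    num_qvalue_nets=2, use_vmap=True)   # default: torch.vmap over the Q-nets
loss_loop = SACLoss(actor_network=actor, qvalue_network=qvalue,
                    num_qvalue_nets=2, use_vmap=False)  # plain Python loop instead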

test/test_cost.py

Lines changed: 65 additions & 0 deletions
@@ -3650,6 +3650,7 @@ def _create_seq_mock_data_sac(
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+    @pytest.mark.parametrize("use_vmap", [False, True])
     def test_sac(
         self,
         delay_value,
@@ -3659,6 +3660,7 @@ def test_sac(
         device,
         version,
         td_est,
+        use_vmap,
     ):
         if (delay_actor or delay_qvalue) and not delay_value:
             pytest.skip("incompatible config")
@@ -3687,6 +3689,7 @@ def test_sac(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            use_vmap=use_vmap,
             **kwargs,
         )
 
@@ -3811,6 +3814,68 @@ def test_sac(
                 p.grad is None or p.grad.norm() == 0.0
             ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
 
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_sac_vmap_equiv(
+        self,
+        device,
+        version,
+        delay_value=True,
+        delay_actor=True,
+        delay_qvalue=True,
+        num_qvalue=4,
+        td_est=None,
+    ):
+        if (delay_actor or delay_qvalue) and not delay_value:
+            pytest.skip("incompatible config")
+
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(device=device)
+
+        actor = self._create_mock_actor(device=device)
+        qvalue = self._create_mock_qvalue(device=device)
+        if version == 1:
+            value = self._create_mock_value(device=device)
+        else:
+            value = None
+
+        kwargs = {}
+        if delay_actor:
+            kwargs["delay_actor"] = True
+        if delay_qvalue:
+            kwargs["delay_qvalue"] = True
+        if delay_value:
+            kwargs["delay_value"] = True
+
+        loss_fn_vmap = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            num_qvalue_nets=num_qvalue,
+            loss_function="l2",
+            use_vmap=True,
+            **kwargs,
+        )
+        loss_fn_novmap = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            num_qvalue_nets=num_qvalue,
+            loss_function="l2",
+            use_vmap=False,
+            **kwargs,
+        )
+        loss_fn_novmap.load_state_dict(loss_fn_vmap.state_dict())
+
+        with torch.no_grad(), _check_td_steady(td), pytest.warns(
+            UserWarning, match="No target network updater"
+        ):
+            rng_state = torch.random.get_rng_state()
+            loss_vmap = loss_fn_vmap(td.clone())
+            torch.random.set_rng_state(rng_state)
+            loss_novmap = loss_fn_novmap(td.clone())
+
+        assert_allclose_td(loss_vmap, loss_novmap)
+
     @pytest.mark.parametrize("delay_value", (True, False))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
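
A note on `test_sac_vmap_equiv` above: both loss modules sample reparametrized actions, so their outputs only match when they consume identical random streams, hence the RNG snapshot-and-rewind between the two forward calls. The pattern in isolation (a generic sketch, not taken from the commit):

import torch

rng_state = torch.random.get_rng_state()  # snapshot the CPU RNG state
a = torch.randn(3)
torch.random.set_rng_state(rng_state)     # rewind so the next draw replays the same stream
b = torch.randn(3)
assert torch.equal(a, b)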

torchrl/objectives/common.py

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ class _AcceptedKeys:
 
     _vmap_randomness = None
    default_value_estimator: ValueEstimators = None
+    use_vmap: bool = True
 
     deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC
 
torchrl/objectives/sac.py

Lines changed: 34 additions & 8 deletions
@@ -28,6 +28,7 @@
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
+    _LoopVmapModule,
     _reduce,
     _vmap_func,
     default_value_kwargs,
@@ -113,6 +114,9 @@ class SACLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        use_vmap (bool, optional): Whether :func:`~torch.vmap` should be used to batch
+            operations. Defaults to ``True``.
+            .. note:: Not using ``vmap`` offers greater flexibility but may incur a slower runtime.
 
     Examples:
         >>> import torch
@@ -307,7 +311,9 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        use_vmap: bool = True,
     ) -> None:
+        self.use_vmap = use_vmap
         self._in_keys = None
         self._out_keys = None
         if reduction is None:
@@ -407,13 +413,22 @@ def __init__(
         self.reduction = reduction
 
     def _make_vmap(self):
-        self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
-        )
-        if self._version == 1:
-            self._vmap_qnetwork00 = _vmap_func(
-                self.qvalue_network, randomness=self.vmap_randomness
+        if self.use_vmap:
+            self._vmap_qnetworkN0 = _vmap_func(
+                self.qvalue_network, (None, 0), randomness=self.vmap_randomness
             )
+            if self._version == 1:
+                self._vmap_qnetwork00 = _vmap_func(
+                    self.qvalue_network, randomness=self.vmap_randomness
+                )
+        else:
+            self._vmap_qnetworkN0 = _LoopVmapModule(
+                self.qvalue_network, (None, 0), functional=True
+            )
+            if self._version == 1:
+                self._vmap_qnetwork00 = _LoopVmapModule(
+                    self.qvalue_network, functional=True
+                )
 
     @property
     def target_entropy_buffer(self):
@@ -579,7 +594,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         else:
             loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
             loss_value = None
-        loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
+        loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
         loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
         tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
         if (loss_actor.shape != loss_qvalue.shape) or (
@@ -614,9 +629,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     def _cached_detached_qvalue_params(self):
         return self.qvalue_network_params.detach()
 
-    def _actor_loss(
+    def actor_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """The loss for the actor.
+
+        Args:
+            tensordict (TensorDictBase): the input data. See :attr:`~.in_keys` for more details
+                on the required fields.
+
+        Returns: a tensor containing the actor loss along with a dictionary of metadata.
+
+        """
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
@@ -626,10 +650,12 @@ def _actor_loss(
 
         td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
         td_q.set(self.tensor_keys.action, a_reparm)
+
         td_q = self._vmap_qnetworkN0(
             td_q,
             self._cached_detached_qvalue_params,  # should we clone?
         )
+
         min_q_logprob = (
             td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
         )
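
Since `_actor_loss` is promoted to the public `actor_loss`, the actor term can also be computed outside of `forward()`, e.g. when stepping actor and critics with separate optimizers. A hedged sketch (`loss_fn` is an already-constructed `SACLoss` and `td` a batch containing the keys listed in `loss_fn.in_keys`; both names are placeholders):

# Illustrative only: `loss_fn` (a SACLoss) and `td` (a TensorDict) are assumed to exist.
loss_actor, metadata = loss_fn.actor_loss(td)
log_prob = metadata["log_prob"]  # log-probabilities of the sampled actions
loss_actor.mean().backward()     # reduce manually here; forward() applies its own reduction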

torchrl/objectives/utils.py

Lines changed: 67 additions & 2 deletions
@@ -8,14 +8,16 @@
 import re
 import warnings
 from enum import Enum
-from typing import Iterable, Optional, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule
+from tensordict.utils import _zip_strict
 from torch import nn, Tensor
 from torch.nn import functional as F
 from torch.nn.modules import dropout
+from torch.utils._pytree import tree_map
 
 try:
     from torch import vmap
@@ -480,7 +482,7 @@ def new_fun(self, netname=None):
     return new_fun
 
 
-def _vmap_func(module, *args, func=None, **kwargs):
+def _vmap_func(module, *args, func=None, call_vmap: bool = True, **kwargs):
     try:
 
         def decorated_module(*module_args_params):
@@ -503,6 +505,69 @@ def decorated_module(*module_args_params):
         ) from err
 
 
+class _LoopVmapModule(nn.Module):
+    def __init__(
+        self,
+        module: nn.Module,
+        in_dims: Tuple[int | None] = None,
+        out_dims: Tuple[int | None] = None,
+        register_module: bool = False,
+        functional: bool = False,
+    ):
+        super().__init__()
+        self.register_module = register_module
+        if not register_module:
+            self.__dict__["module"] = module
+        else:
+            self.module = module
+        self.in_dims = in_dims
+        if out_dims is not None:
+            raise NotImplementedError("out_dims not implemented yet.")
+        self.out_dims = out_dims
+        self.functional = functional
+
+    def forward(self, *args):
+        n = None
+        to_rep = []
+        if self.in_dims is None:
+            self.in_dims = [0] * len(args)
+        args = list(args)
+        for i, (arg, in_dim) in enumerate(_zip_strict(args, self.in_dims)):
+            if in_dim is not None:
+                arg = arg.unbind(in_dim)
+                if n is None:
+                    n = len(arg)
+                elif n != len(arg):
+                    raise ValueError(
+                        f"The length of the unbound args differs: {n} vs {len(arg)}."
+                    )
+                args[i] = arg
+            else:
+                to_rep.append(i)
+        args = [
+            tuple(arg.copy() for _ in range(n)) if i in to_rep else arg
+            for i, arg in enumerate(args)
+        ]
+        out = []
+        n_out = None
+        for _args in zip(*args):
+            if self.functional:
+                with _args[-1].to_module(self.module):
+                    out.append(self.module(*_args[:-1]))
+            else:
+                out.append(self.module(*_args))
+            if n_out is None:
+                n_out = len(out[-1]) if isinstance(out[-1], tuple) else 1
+        if n_out > 1:
+            return tree_map(lambda *x: torch.stack(out, dim=0), *out)
+        elif n_out == 1:
+            # We explicitly assume that out can be stacked
+            result = torch.stack(out, dim=0)
+            return result
+        else:
+            raise ValueError("Could not determine the number of outputs.")
+
+
 def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
     """Reduces a tensor given the reduction method."""
     if reduction == "none":
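
To see what `_LoopVmapModule` stands in for, here is a standalone comparison against `torch.vmap` on a plain module (a sketch assuming `functional=False` and a single mapped input; `_LoopVmapModule` is a private helper, so its signature may change):

import torch
from torch import nn
from torchrl.objectives.utils import _LoopVmapModule

module = nn.Linear(3, 2)
x = torch.randn(5, 4, 3)  # map over the leading dim of size 5

y_loop = _LoopVmapModule(module, in_dims=(0,))(x)  # unbind dim 0, loop, stack
y_vmap = torch.vmap(module)(x)                     # batched equivalent
assert torch.allclose(y_loop, y_vmap)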
