Skip to content

[Feature] Add Choice spec #2713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import torch
import torchrl.data.tensor_specs
from scipy.stats import chisquare
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict import (
LazyStackedTensorDict,
NonTensorData,
NonTensorStack,
TensorDict,
TensorDictBase,
)
from tensordict.utils import _unravel_key_to_tuple
from torchrl._utils import _make_ordinal_device

Expand All @@ -23,6 +29,7 @@
Bounded,
BoundedTensorSpec,
Categorical,
Choice,
Composite,
CompositeSpec,
ContinuousBox,
Expand Down Expand Up @@ -678,6 +685,23 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
assert ts["nested"].shape == (3,)


class TestChoiceSpec:
    """Smoke tests for the :class:`Choice` spec."""

    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
    def test_choice(self, input_type):
        """A random draw from a Choice spec must be contained in the spec."""
        builders = {
            "spec": lambda: torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]),
            "nontensor": lambda: torch.stack(
                [NonTensorData("a"), NonTensorData("b")]
            ),
            "nontensorstack": lambda: torch.stack(
                [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
            ),
        }
        choice_spec = Choice(builders[input_type]())
        sample = choice_spec.rand()
        assert choice_spec.is_in(sample)


@pytest.mark.parametrize("shape", [(), (2, 3)])
@pytest.mark.parametrize("device", get_default_devices())
def test_create_composite_nested(shape, device):
Expand Down Expand Up @@ -1409,6 +1433,21 @@ def test_non_tensor(self):
== NonTensor((2, 3, 4), device="cpu")
)

@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """``Choice.expand`` is unsupported and must raise ``NotImplementedError``."""
    if input_type == "spec":
        options = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
    elif input_type == "nontensor":
        options = torch.stack([NonTensorData("a"), NonTensorData("b")])
    else:  # "nontensorstack"
        options = torch.stack(
            [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
        )
    choice_spec = Choice(options)
    with pytest.raises(NotImplementedError):
        choice_spec.expand((3,))

@pytest.mark.parametrize("shape1", [None, (), (5,)])
@pytest.mark.parametrize("shape2", [(), (10,)])
def test_onehot(self, shape1, shape2):
Expand Down Expand Up @@ -1611,6 +1650,21 @@ def test_non_tensor(self):
assert spec.clone() == spec
assert spec.clone() is not spec

@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """``Choice.clone`` returns an equal but distinct spec instance."""
    builders = {
        "spec": lambda: torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]),
        "nontensor": lambda: torch.stack([NonTensorData("a"), NonTensorData("b")]),
        "nontensorstack": lambda: torch.stack(
            [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
        ),
    }
    choice_spec = Choice(builders[input_type]())
    copy = choice_spec.clone()
    assert copy == choice_spec
    assert copy is not choice_spec

@pytest.mark.parametrize("shape1", [None, (), (5,)])
def test_onehot(
self,
Expand Down Expand Up @@ -1696,6 +1750,35 @@ def test_non_tensor(self):
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
spec.cardinality()

@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type):
    """``Choice.cardinality`` sums spec cardinalities / counts non-tensor options."""
    # Each case: (stack builder, expected cardinality).
    cases = {
        "bounded_spec": (
            lambda: torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]),
            float("inf"),
        ),
        "categorical_spec": (
            lambda: torch.stack([Categorical(10), Categorical(20)]),
            30,
        ),
        "nontensor": (
            lambda: torch.stack(
                [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
            ),
            3,
        ),
        "nontensorstack": (
            lambda: torch.stack(
                [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
            ),
            2,
        ),
    }
    build, expected = cases[input_type]
    assert Choice(build()).cardinality() == expected

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
self,
Expand Down Expand Up @@ -2004,6 +2087,27 @@ def test_non_tensor(self, device):
spec = NonTensor(shape=(3, 4), device="cpu")
assert spec.to(device).device == device

@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type, device):
    """``Choice.to(device)`` moves the spec to the requested device."""
    builders = {
        "bounded_spec": lambda: torch.stack(
            [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
        ),
        "categorical_spec": lambda: torch.stack([Categorical(10), Categorical(20)]),
        "nontensor": lambda: torch.stack(
            [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
        ),
        "nontensorstack": lambda: torch.stack(
            [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
        ),
    }
    choice_spec = Choice(builders[input_type](), device="cpu")
    assert choice_spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(self, shape1, device):
if shape1 is None:
Expand Down Expand Up @@ -2270,6 +2374,29 @@ def test_stack_non_tensor(self, shape, stack_dim):
assert new_spec.shape == torch.Size(shape_insert)
assert new_spec.device == torch.device("cpu")

@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_stack_choice(self, input_type, shape, stack_dim):
    """Stacking Choice specs is unsupported and must raise ``NotImplementedError``."""
    builders = {
        "bounded_spec": lambda: torch.stack(
            [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
        ),
        "categorical_spec": lambda: torch.stack([Categorical(10), Categorical(20)]),
        "nontensor": lambda: torch.stack(
            [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
        ),
        "nontensorstack": lambda: torch.stack(
            [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
        ),
    }
    options = builders[input_type]()
    specs = [Choice(options), Choice(options)]
    with pytest.raises(NotImplementedError):
        torch.stack(specs, 0)

def test_stack_onehot(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
BoundedContinuous,
BoundedTensorSpec,
Categorical,
Choice,
Composite,
CompositeSpec,
DEVICE_TYPING,
Expand Down
94 changes: 94 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_tensor_collection,
LazyStackedTensorDict,
NonTensorData,
NonTensorStack,
TensorDict,
TensorDictBase,
unravel_key,
Expand Down Expand Up @@ -3678,6 +3679,99 @@ def clone(self) -> Categorical:
)


class Choice(TensorSpec):
    """A discrete choice spec for either tensor or non-tensor data.

    On each :meth:`rand` call one element of ``stack`` is picked uniformly at
    random; if that element is itself a spec it is sampled, otherwise the
    element is returned as-is.

    Args:
        stack (:class:`~Stacked`, :class:`~StackedComposite`, or :class:`~tensordict.NonTensorStack`):
            Stack of specs or non-tensor data from which to choose during
            sampling.
        device (str, int or torch.device, optional): device of the tensors.
            Defaults to ``stack.device``.

    Raises:
        TypeError: if ``stack`` is not one of the supported stack types.

    Examples:
        >>> import torch
        >>> _ = torch.manual_seed(0)
        >>> from torchrl.data import Choice, Categorical
        >>> spec = Choice(torch.stack([
        ...     Categorical(n=4, shape=(1,)),
        ...     Categorical(n=4, shape=(2,))]))
        >>> spec.shape
        torch.Size([2, -1])
        >>> spec.rand()
        tensor([3])
        >>> spec.rand()
        tensor([0, 3])
    """

    def __init__(
        self,
        stack: Stacked | StackedComposite | NonTensorStack,
        device: Optional[DEVICE_TYPING] = None,
    ):
        if not isinstance(stack, (Stacked, StackedComposite, NonTensorStack)):
            # Raise instead of assert so validation survives ``python -O``.
            raise TypeError(
                "Choice expects a Stacked, StackedComposite or NonTensorStack, "
                f"got {type(stack).__name__}."
            )
        # Clone so later in-place changes to the caller's stack don't leak in.
        stack = stack.clone()
        if device is not None:
            self._stack = stack.to(device)
        else:
            self._stack = stack
            device = stack.device

        super().__init__(
            shape=stack.shape,
            space=None,
            device=device,
            dtype=stack.dtype,
            domain=None,
        )

    def _rand_idx(self):
        # Uniformly pick the index of one element of the stack.
        return torch.randint(0, len(self._stack), ()).item()

    def _sample(self, idx, spec_sample_fn) -> TensorDictBase:
        # Sample from the selected element when it is a spec; non-tensor
        # elements are returned as-is.
        res = self._stack[idx]
        if isinstance(res, TensorSpec):
            return spec_sample_fn(res)
        else:
            return res

    def zero(self, shape: torch.Size = None) -> TensorDictBase:
        # Deterministically derive the "zero" sample from the first element.
        return self._sample(0, lambda x: x.zero(shape))

    def one(self, shape: torch.Size = None) -> TensorDictBase:
        # Use the second element when the stack has one, so that ``one`` and
        # ``zero`` come from different choices; singleton stacks fall back to
        # index 0. Fixes the original ``len(self - 1)``, which raised
        # TypeError (subtraction on the spec object instead of on the length).
        return self._sample(min(1, len(self._stack) - 1), lambda x: x.one(shape))

    def rand(self, shape: torch.Size = None) -> TensorDictBase:
        return self._sample(self._rand_idx(), lambda x: x.rand(shape))

    def is_in(self, val: torch.Tensor | TensorDictBase) -> bool:
        # A value belongs to the spec if any element of the stack accepts it.
        if isinstance(self._stack, (Stacked, StackedComposite)):
            return any(stack_elem.is_in(val) for stack_elem in self._stack)
        else:
            return any((stack_elem == val).all() for stack_elem in self._stack)

    def expand(self, *shape):
        raise NotImplementedError

    def unsqueeze(self, dim: int):
        raise NotImplementedError

    def clone(self) -> Choice:
        # ``__init__`` clones the stack, so this yields an independent copy.
        return self.__class__(self._stack)

    def cardinality(self) -> int:
        # Non-tensor stacks contribute one choice per element; spec stacks
        # add up their elements' cardinalities (may be ``float("inf")``).
        if isinstance(self._stack, NonTensorStack):
            return len(self._stack)
        else:
            return (
                torch.tensor([stack_elem.cardinality() for stack_elem in self._stack])
                .sum()
                .item()
            )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
        return self.__class__(self._stack.to(dest))


@dataclass(repr=False)
class Binary(Categorical):
"""A binary discrete tensor spec.
Expand Down
Loading