Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion acme/adders/reverb/episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,26 @@


class EpisodeAdder(base.ReverbAdder):
"""Adder which adds entire episodes as trajectories."""
"""Adder which adds entire episodes as trajectories.

This adder accumulates all steps of an episode and inserts them as a single
trajectory item into Reverb at the end of the episode. It is useful for
algorithms that require full episodes (e.g., for offline learning or MCTS).

Args:
client: The Reverb client to use for data insertion.
max_sequence_length: The maximum length of an episode. Episodes longer
than this will raise a ValueError. If padding_fn is provided, episodes
shorter than this will be padded to this length.
delta_encoded: Whether to use delta encoding for the trajectory.
priority_fns: A mapping from table names to priority functions.
max_in_flight_items: The maximum number of items allowed to be in flight
(being sent to Reverb) at the same time.
padding_fn: An optional callable that takes a shape and dtype and returns
a zero-filled (or otherwise equivalent 'empty') array of that shape and
dtype. If provided, episodes shorter than max_sequence_length will be
padded.
"""

def __init__(
self,
Expand Down
20 changes: 5 additions & 15 deletions acme/tf/networks/legal_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,8 @@ def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor:
return outputs


# FIXME: Add functionality to support decaying epsilon parameter.
# FIXME: This is a modified version of trfl's epsilon_greedy() which
# incorporates code from the bug fix described here
# https://github.com/deepmind/trfl/pull/28
class EpsilonGreedy(snt.Module):
"""Computes an epsilon-greedy distribution over actions.

This policy does the following:
- With probability 1 - epsilon, take the action corresponding to the highest
action value, breaking ties uniformly at random.
- With probability epsilon, take an action uniformly at random.
"""

def __init__(self,
epsilon: Union[tf.Tensor, float],
epsilon: Union[tf.Tensor, float, tf.Variable],
threshold: float,
name: str = 'EpsilonGreedy'):
"""Initialize the policy.
Expand All @@ -95,7 +82,10 @@ def __init__(self,
policy.
"""
super().__init__(name=name)
self._epsilon = tf.Variable(epsilon, trainable=False)
if isinstance(epsilon, tf.Variable):
self._epsilon = epsilon
else:
self._epsilon = tf.Variable(epsilon, trainable=False)
self._threshold = threshold

def __call__(self, action_values: tf.Tensor) -> tfd.Categorical:
Expand Down
18 changes: 9 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
# sure this constraint is upheld.

tensorflow = [
'tensorflow==2.8.0',
'tensorflow_probability==0.15.0',
'tensorflow_datasets==4.6.0',
'dm-reverb==0.7.2',
'dm-launchpad==0.5.2',
'tensorflow>=2.8.0',
'tensorflow_probability>=0.15.0',
'tensorflow_datasets>=4.6.0',
'dm-reverb>=0.7.2',
'dm-launchpad>=0.5.2',
]

core_requirements = [
Expand All @@ -54,8 +54,8 @@
]

jax_requirements = [
'jax==0.4.3',
'jaxlib==0.4.3',
'jax>=0.4.3',
'jaxlib>=0.4.3',
'chex',
'dm-haiku',
'flax',
Expand All @@ -77,9 +77,9 @@
'atari-py',
'bsuite',
'dm-control',
'gym==0.25.0',
'gym>=0.25.0,<0.26.0',
'gym[atari]',
'pygame==2.1.0',
'pygame>=2.1.0',
'rlds',
]

Expand Down