Skip to content

Added a VectorizedGymWrapper #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions brax/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Some example environments to help get started quickly with brax."""

import functools
from typing import Callable
from typing import Callable, Union, overload

import gym
import brax
Expand Down Expand Up @@ -54,6 +54,21 @@ def create_fn(env_name: str, **kwargs) -> Callable[..., Env]:
return functools.partial(create, env_name, **kwargs)


def create_gym_env(env_name: str, **kwargs) -> gym.Env:
"""Creates a Gym Env with a specified brax system."""
return wrappers.GymWrapper(create(env_name, **kwargs))
@overload
def create_gym_env(env_name: str, seed: int = 0, **kwargs) -> wrappers.GymWrapper:
...


@overload
def create_gym_env(env_name: str, batch_size: int, seed: int = 0, **kwargs) -> wrappers.VectorGymWrapper:
...


def create_gym_env(
env_name: str, batch_size: int = 0, seed: int = 0, **kwargs
) -> Union[wrappers.GymWrapper, wrappers.VectorGymWrapper]:
if batch_size:
return wrappers.VectorGymWrapper(create(env_name=env_name, batch_size=batch_size, **kwargs), seed=seed)
else:
return wrappers.GymWrapper(create(env_name=env_name, **kwargs), seed=seed)

49 changes: 48 additions & 1 deletion brax/envs/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import gym
from gym import spaces
from gym.vector.utils import batch_space
import jax
import numpy as np
from brax.envs import env
Expand All @@ -26,7 +27,7 @@ class GymWrapper(gym.Env):

def __init__(self, environment: env.Env, seed: int = 0):
self._environment = environment
self._key = jax.random.PRNGKey(seed)
self.seed(seed)

# action_space = None
obs_high = np.inf * np.ones(self._environment.observation_size)
Expand Down Expand Up @@ -55,3 +56,49 @@ def reset(self):
def step(self, action):
self._state, obs, reward, done = self._step(self._state, action)
return obs, reward, done, {}

def seed(self, seed: int = 0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this method - please call it from the initializer like you do in the other class.

self._key = jax.random.PRNGKey(seed)


class VectorGymWrapper(gym.vector.VectorEnv):
"""A wrapper that converts batched Brax Env to one that follows Gym VectorEnv API."""

def __init__(self, environment: env.Env, seed: int = 0):
self._environment = environment
assert self._environment.batch_size # Make sure underlying environment is batched

self.num_envs = self._environment.batch_size
self._key_size = self.num_envs + 1
self.seed(seed)

obs_high = np.inf * np.ones(self._environment.observation_size)
self.single_observation_space = spaces.Box(-obs_high, obs_high, dtype=np.float32)
self.observation_space = batch_space(self.single_observation_space, self.num_envs)

action_high = np.ones(self._environment.action_size)
self.single_action_space = spaces.Box(-action_high, action_high, dtype=np.float32)
self.action_space = batch_space(self.single_action_space, self.num_envs)
self._state = None

def reset(key):
keys = jax.random.split(key, self._key_size)
state = self._environment.reset(keys[1:])
return state, state.obs, keys[0]
self._reset = jax.jit(reset)

def step(state, action):
state = self._environment.step(state, action)
return state, state.obs, state.reward, state.done
self._step = jax.jit(step, backend='cpu')

def reset(self):
self._state, obs, self._key = self._reset(self._key)
return obs

def step(self, action):
self._state, obs, reward, done = self._step(self._state, action)
return obs, reward, done, {}

def seed(self, seed: int = 0):
self._key = jax.random.PRNGKey(seed)