Added a VectorizedGymWrapper #20

Merged
merged 3 commits into from
Aug 2, 2021

Conversation

DavidSlayback
Contributor

One of the main advantages of Brax is being able to create many parallel instances of physics environments. The original GymWrapper lets users use single-instance Brax environments in the familiar Gym API. This proposed change provides a vectorized version that follows the Gym VectorEnv API.

@google-cla google-cla bot added the cla: yes label Jul 26, 2021
@vwxyzjn
Contributor

vwxyzjn commented Jul 27, 2021

Oh, I really like this change! Is this VectorizedGymWrapper compatible with openai/baselines-style vectorized envs? Are there any libraries built to use the VectorizedGymWrapper API?

@DavidSlayback
Contributor Author

Oh, I really like this change! Is this VectorizedGymWrapper compatible with openai/baselines-style vectorized envs? Are there any libraries built to use the VectorizedGymWrapper API?

Not directly, but the only big difference is that the Baselines VecEnv stores just the single observation and action spaces, while the gym.vector version also keeps a batched space (easier to sample from). It wouldn't be too hard to address that, perhaps with separate "create_env" functions for each.
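
To make that difference concrete, here is a minimal sketch of the two conventions. The `Box` dataclass and `batch_space` helper are hypothetical stand-ins written for illustration (Gym ships its own `gym.spaces.Box` and `gym.vector.utils.batch_space`), and the shapes and `num_envs` value are made up.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Box:
    """Toy stand-in for a continuous space: only bounds and shape."""
    low: float
    high: float
    shape: Tuple[int, ...]


def batch_space(space: Box, n: int) -> Box:
    """Return a space describing n stacked copies of `space`."""
    return Box(space.low, space.high, (n,) + space.shape)


num_envs = 4
single_observation_space = Box(-1.0, 1.0, (3,))
single_action_space = Box(-1.0, 1.0, (2,))

# Baselines-style VecEnv: exposes only the per-environment spaces.
baselines_style = {
    "observation_space": single_observation_space,
    "action_space": single_action_space,
}

# gym.vector-style VectorEnv: keeps the single spaces and adds batched ones.
gym_vector_style = {
    "single_observation_space": single_observation_space,
    "single_action_space": single_action_space,
    "observation_space": batch_space(single_observation_space, num_envs),
    "action_space": batch_space(single_action_space, num_envs),
}

print(gym_vector_style["observation_space"].shape)  # (4, 3)
print(baselines_style["observation_space"].shape)   # (3,)
```

An adapter between the two would mostly need to decide which attribute name carries the per-environment space.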

@vwxyzjn
Contributor

vwxyzjn commented Jul 27, 2021

Would you mind adding this Baselines VecEnv adapter? It would make it much easier to use the environments with libraries such as Stable-Baselines3. I'm also happy to build an integration that displays the Brax HTML visualization in the Weights & Biases dashboard (e.g. like this), giving quicker insight into the agent's behavior at different stages of training.

@vwxyzjn
Contributor

vwxyzjn commented Jul 27, 2021

Also on a related note, it might be tricky to use other DL learners that are not implemented in Jax.

I was testing out the Gym integration with PyTorch on a GPU, and throughput was about as slow as with the PyBullet envs. Then I realized this was probably because the data flows as JAX GPU arrays -> JAX CPU arrays -> PyTorch GPU tensors, which negates the speed advantage of Brax.

I think there are some issues tracking this to make the transfer much easier such as jax-ml/jax#1865 and jax-ml/jax#1100, but overall I'm not sure how to resolve this.

I think using your vectorized environment interface is going to speed things up, but I was wondering whether you have run any speed tests.
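
A throughput check of the kind asked about here could look like the sketch below. `ToyVectorEnv` is a hypothetical pure-Python stand-in, not Brax, and `steps_per_second` is an illustrative helper. A real measurement would swap in the wrapped environment and, for JAX, keep the first (compiling) call out of the timed region and wait for asynchronous results before stopping the timer.

```python
import time
from typing import List, Tuple


class ToyVectorEnv:
    """Hypothetical batched env: each sub-env is a counter that ends at 10."""

    def __init__(self, num_envs: int):
        self.num_envs = num_envs
        self._state = [0] * num_envs

    def reset(self) -> List[int]:
        self._state = [0] * self.num_envs
        return list(self._state)

    def step(self, actions: List[int]) -> Tuple[List[int], List[float], List[bool]]:
        self._state = [s + a for s, a in zip(self._state, actions)]
        rewards = [float(a) for a in actions]
        dones = [s >= 10 for s in self._state]
        # Auto-reset finished sub-environments, as vectorized envs usually do.
        self._state = [0 if d else s for s, d in zip(self._state, dones)]
        return list(self._state), rewards, dones


def steps_per_second(env: ToyVectorEnv, num_steps: int, warmup: int = 10) -> float:
    """Environment steps per second, counting every sub-env as one step."""
    actions = [1] * env.num_envs
    env.reset()
    for _ in range(warmup):  # untimed warm-up (JIT compilation in a real setup)
        env.step(actions)
    start = time.perf_counter()
    for _ in range(num_steps):
        env.step(actions)
    elapsed = time.perf_counter() - start
    return num_steps * env.num_envs / max(elapsed, 1e-9)


# Compare throughput across batch sizes; absolute numbers depend on the machine.
for n in (1, 64, 1024):
    rate = steps_per_second(ToyVectorEnv(n), num_steps=200)
    print(f"num_envs={n}: {rate:.0f} steps/s")
```

Counting every sub-environment as one step keeps results comparable between a single-instance wrapper and a batched one.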

@DavidSlayback
Contributor Author

I could definitely do that, though I'm not affiliated with the team. I'm just using Brax as an alternative to MuJoCo in my own research and figured this was an easy feature to code up quickly without seriously changing the codebase!

As for speed, this commit is deliberately written to change the behavior of the original GymWrapper as little as possible. In my own repo with PyTorch, I jit with a CPU backend in both step and reset (to avoid the first GPU -> CPU transfer). Still better than MuJoCo/PyBullet!

Having listened to a lot of the PyTorch Dev podcast, I'm not optimistic about ever seeing a way to convert directly between PyTorch and JAX GPU tensors; there's a lot under the hood :(

1. StableBaselines3 wrapper, subclassing abstract VecEnv. Very basic at this point (no "render" or abstract method implementations). StableBaselines3 dependency is left optional

2. Made "create_gym_vector_env" and "create_baselines_vec_env" functions in __init__

3. Added "seed" method to base GymWrapper
@DavidSlayback
Contributor Author

I made a very basic implementation of the Baselines wrapper you suggested. I didn't want to add the dependency, so the import is wrapped in a try-except. Many of the methods I didn't implement aren't relevant (e.g., "step_async", "step_wait"), but some of the skipped ones, like render, could be useful.



try:
  from stable_baselines3.common.vec_env import VecEnv
Contributor

I don't feel like this belongs here; in my opinion, it should be done in the SB3 repo.

Collaborator

+1 - also given it's only partially implemented, prefer to leave this out. Thanks!

Collaborator

@erikfrey erikfrey left a comment

This is a nice improvement - thanks for submitting it! Just left a host of nits, after which we'll be good to go. Cheers.

@@ -57,3 +57,16 @@ def create_fn(env_name: str, **kwargs) -> Callable[..., Env]:
def create_gym_env(env_name: str, **kwargs) -> gym.Env:
  """Creates a Gym Env with a specified brax system."""
  return wrappers.GymWrapper(create(env_name, **kwargs))


def create_gym_vector_env(env_name: str, **kwargs) -> gym.vector.VectorEnv:
Collaborator

Rather than create a new function, I think we can just check the batch_size parameter in create_gym_env and return VecGymWrapper if it's > 0

Contributor

@lebrice lebrice Aug 2, 2021

Might I suggest this:

from typing import Optional, Union, overload

@overload
def create_env(env_name: str, seed: int = 0, **kwargs) -> GymWrapper:
    ...

@overload
def create_env(env_name: str, batch_size: int, seed: int = 0, **kwargs) -> VectorGymWrapper:
    ...

def create_env(
    env_name: str, batch_size: Optional[int] = None, seed: int = 0, **kwargs
) -> Union[GymWrapper, VectorGymWrapper]:
    if batch_size is None:
        # No batch size requested: wrap a single environment.
        return GymWrapper(create(env_name=env_name, **kwargs), seed=seed)
    else:
        # Build a batched environment and expose it through the vectorized API.
        return VectorGymWrapper(create(env_name=env_name, batch_size=batch_size, **kwargs), seed=seed)
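
The dispatch discussed in this thread can be exercised without Brax by using stand-in classes. Everything below is hypothetical scaffolding: `FakeEnv`, the two wrapper stubs, and this `create_gym_env` signature are illustrations, not the merged code.

```python
from typing import Optional, Union


class FakeEnv:
    """Stand-in for a Brax environment; only records its batch size."""

    def __init__(self, env_name: str, batch_size: Optional[int] = None):
        self.env_name = env_name
        self.batch_size = batch_size


class GymWrapper:
    """Stub for the single-environment wrapper."""

    def __init__(self, env: FakeEnv, seed: int = 0):
        self.env, self.seed_value = env, seed


class VectorGymWrapper:
    """Stub for the vectorized wrapper; exposes num_envs like a VectorEnv."""

    def __init__(self, env: FakeEnv, seed: int = 0):
        self.env, self.seed_value = env, seed
        self.num_envs = env.batch_size


def create_gym_env(
    env_name: str, batch_size: Optional[int] = None, seed: int = 0
) -> Union[GymWrapper, VectorGymWrapper]:
    """Return the vectorized wrapper only when a batch size is requested."""
    if batch_size is None:
        return GymWrapper(FakeEnv(env_name), seed=seed)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    return VectorGymWrapper(FakeEnv(env_name, batch_size), seed=seed)


single = create_gym_env("ant")
batched = create_gym_env("ant", batch_size=16)
print(type(single).__name__, type(batched).__name__)
```

Rejecting non-positive values is a choice made for this sketch; the review comment above instead suggests returning the vectorized wrapper only when `batch_size` is `> 0`.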




  return wrappers.VecGymWrapper(create(env_name, **kwargs))


def create_baselines_vec_env(env_name: str, **kwargs):
Collaborator

Let's avoid this for now - would rather more fully support this integration after we get more requests to support it, so that we know it's worth it.

@@ -55,3 +56,119 @@ def reset(self):
  def step(self, action):
    self._state, obs, reward, done = self._step(self._state, action)
    return obs, reward, done, {}

  def seed(self, seed: int = 0):
Collaborator

Thanks for adding this method - please call it from the initializer like you do in the other class.

    self._key = jax.random.PRNGKey(seed)
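
The request to call `seed` from the initializer can be sketched as follows. `random.Random` stands in for `jax.random.PRNGKey` so the snippet runs without JAX, and `SeededWrapper` is a hypothetical name, not the class in this PR.

```python
import random


class SeededWrapper:
    """Hypothetical wrapper whose RNG state always exists after __init__."""

    def __init__(self, seed: int = 0):
        # Route construction through seed() so there is a single code path
        # for creating the RNG, whether at init time or on re-seeding.
        self.seed(seed)

    def seed(self, seed: int = 0) -> None:
        # In the PR this line is `self._key = jax.random.PRNGKey(seed)`.
        self._rng = random.Random(seed)

    def sample(self) -> float:
        return self._rng.random()


a, b = SeededWrapper(seed=42), SeededWrapper(seed=42)
assert a.sample() == b.sample()  # same seed, same stream

a.seed(7)
b.seed(7)
assert a.sample() == b.sample()  # re-seeding restarts both identically
```

With this pattern, a caller who never invokes `seed` explicitly still gets a wrapper with a defined RNG state.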


class VecGymWrapper(gym.vector.VectorEnv):
Collaborator

prefer VectorGymWrapper just to keep parity with naming conventions.

1. Removed StableBaselines3 VecEnvWrapper
2. Added seed to GymWrapper initialization
3. Renamed VecGymWrapper -> VectorGymWrapper
4. Merged create_gym_env and create_gym_vector_env
@DavidSlayback
Contributor Author

I removed the draft work on the Baselines VecEnv and updated the PR in response to the comments! As a side note, nice work on the torch and backend wrappers, @lebrice! I thought about including the backend as an argument here but deferred to your PR.

@DavidSlayback DavidSlayback requested a review from erikfrey August 2, 2021 21:21
Collaborator

@erikfrey erikfrey left a comment

Looks great, thank you!

@erikfrey erikfrey merged commit 6ed6e8f into google:main Aug 2, 2021