-
Notifications
You must be signed in to change notification settings - Fork 297
Added a VectorizedGymWrapper #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Oh, I really like this change! Is this |
Not directly, but the only big difference is that the Baselines VecEnv use the store just the single observation and action spaces, while the Gym.vector version keeps a batched space (easier to sample from) in addition. It wouldn't be too hard to address that and maybe make separate "create_env" functions for them |
Would you mind adding this |
Also on a related note, it might be tricky to use other DL learners that are not implemented in Jax. I was testing out the gym integration and tried to make it work with PyTorch using GPU, and it had similar (slow) throughput like the pybullet envs. Then, I realized it was probably because the flow is like I think there are some issues tracking this to make the transfer much easier such as jax-ml/jax#1865 and jax-ml/jax#1100, but overall I'm not sure how to resolve this. I think using your vectorized environment interface is going to speed up things, but was wondering if you had done any speed tests. |
I could definitely do that, though I'm not really affiliated with the team, I'm just using Brax as an alternative to Mujoco in my own research and figured it was an easy feature to quickly code up without seriously changing the codebase! As far as speed, this commit is specifically made to alter the behavior from the original gym as little as possible. In my own repo with pytorch, I do the jit with a cpu backend in both step and reset (to avoid the first gpu -> cpu transfer). Still better than Mujoco/Pybullet! Having listened to a lot of the Pytorch Dev podcast, I'm not optimistic for ever seeing a way to directly convert between PyTorch and JAX GPU tensors, there's a lot under the hood :( |
1. StableBaselines3 wrapper, subclassing abstract VecEnv. Very basic at this point (no "render" or abstract method implementations). StableBaselines3 dependency is left optional 2. Made "create_gym_vector_env" and "create_baselines_vec_env" functions in __init__ 3. Added "seed" method to base GymWrapper
I made a very basic implementation of the Baselines wrapper you suggested. I didn't want to add the dependency, so it's wrapped in a try-except, and while many of the methods I didn't implement aren't relevant (e.g., "step_async, step_wait"), there are some ignored methods like render that could be useful. |
brax/envs/wrappers.py
Outdated
|
||
|
||
try: | ||
from stable_baselines3.common.vec_env import VecEnv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't feel like this belongs here, this should be done in the SB3 repo, in my opinion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 - also given it's only partially implemented, prefer to leave this out. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a nice improvement - thanks for submitting it! Just left a host of nits, after which we'll be good to go. Cheers.
brax/envs/__init__.py
Outdated
@@ -57,3 +57,16 @@ def create_fn(env_name: str, **kwargs) -> Callable[..., Env]: | |||
def create_gym_env(env_name: str, **kwargs) -> gym.Env: | |||
"""Creates a Gym Env with a specified brax system.""" | |||
return wrappers.GymWrapper(create(env_name, **kwargs)) | |||
|
|||
|
|||
def create_gym_vector_env(env_name: str, **kwargs) -> gym.vector.VectorEnv: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than create a new function, I think we can just check the batch_size parameter in create_gym_env
and return VecGymWrapper
if it's > 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might I suggest this:
from typing import overload
@overload
def create_env(env_name: str, seed: int = 0, **kwargs) -> GymWrapper:
...
@overload
def create_env(env_name: str, batch_size: int, seed: int = 0, **kwargs) -> VectorGymWrapper:
...
def create_env(
env_name: str, batch_size: int = None, seed: int = 0, **kwargs
) -> Union[GymWrapper, VectorGymWrapper]:
if batch_size is None:
return GymWrapper(create(env_name=env_name, **kwargs), seed=seed)
else:
return VectorGymWrapper(create(env_name=env_name, batch_size=batch_size, **kwargs), seed=seed)
brax/envs/wrappers.py
Outdated
|
||
|
||
try: | ||
from stable_baselines3.common.vec_env import VecEnv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 - also given it's only partially implemented, prefer to leave this out. Thanks!
brax/envs/__init__.py
Outdated
return wrappers.VecGymWrapper(create(env_name, **kwargs)) | ||
|
||
|
||
def create_baselines_vec_env(env_name: str, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's avoid this for now - would rather more fully support this integration after we get more requests to support it, so that we know it's worth it.
@@ -55,3 +56,119 @@ def reset(self): | |||
def step(self, action): | |||
self._state, obs, reward, done = self._step(self._state, action) | |||
return obs, reward, done, {} | |||
|
|||
def seed(self, seed: int = 0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this method - please call it from the initializer like you do in the other class.
brax/envs/wrappers.py
Outdated
self._key = jax.random.PRNGKey(seed) | ||
|
||
|
||
class VecGymWrapper(gym.vector.VectorEnv): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prefer VectorGymWrapper
just to keep parity with naming conventions.
1. Removed StableBaselines3 VecEnvWrapper 2. Added seed to GymWrapper initialization 3. Renamed VecGymWrapper -> VectorGymWrapper 4. Merged create_gym_env and create_gym_vector_env
I removed the draft work on the Baselines VecEnv and updated in response to the comments! As a side note, nice work on the torch and backend wrappers @lebrice ! I thought about including the backend as an argument here but deferred to your PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thank you!
One of the main advantages of Brax is being able to create many parallel instances of physics environments. The original GymWrapper lets users use single-instance Brax environments in the familiar Gym API. This proposed change provides a vectorized version that follows the Gym VectorEnv API.