[RLlib] Add example: Pre-train an RLModule single-agent, then bring checkpoint into multi-agent setup and continue training.
#44674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged: sven1977 merged 41 commits into ray-project:master from simonsays1980:rl-module-pre-training-example on Apr 16, 2024
Commits (41, all shown)
5a85f24 simonsays1980: Defined new properties '_model_auto_keys' and 'model_config_dict' to …
0c2dbc9 simonsays1980: Added new properties to 'PPO' and 'BC' in new stack and modified tune…
aa84898 simonsays1980: Added new property to SAC in new stack and modified tuned example acc…
56f4cf6 simonsays1980: Fixed multiple tests and examples to incorporate the new 'model_confi…
9e50e87 simonsays1980: Reran multiple tests as some were failing in CI. Error in VisionNet r…
70067f9 simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
0e56410 simonsays1980: Changed example settings as many example runs were failing in CI tests.
904ca01 simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
f50c0f5 simonsays1980: Included @sven1977's review.
2055087 simonsays1980: Quick fix as tests were failing because hard deprecation of '_enable_…
8b7fbfd simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
61bb265 simonsays1980: Adapted some connector examples to using the new model config.
32cf531 simonsays1980: Fixed a minor bug in PPO model config.
3f365e0 simonsays1980: Removed bug in SingleAgentEnvRunner using the old model_config_dict i…
0d1635e simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
84269c4 simonsays1980: Adjusted a couple of examples to use the 'model_config_dict' in 'rl_m…
7d92725 simonsays1980: Fixed some minor bugs in CI tests.
ff8c29a simonsays1980: Fixed bug in docs.
6f1252f simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
74017bd simonsays1980: Merging master and linting.
9fc4f86 simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
5727729 simonsays1980: Fixed remaining bugs in test and docs caused by external RLModule's n…
b2f92f2 simonsays1980: Moved example to another branch.
21cea3b simonsays1980: Added a pre-training example for RLModules in a MARL setting. Pre-tra…
9f595f7 simonsays1980: Fixed failing test in 'rllib/tests'. The test was failing because the…
f5d37c2 simonsays1980: Fixed failing test code in docs.
57db62b simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
f465a3b sven1977: Merge branch 'master' of https://github.com/ray-project/ray into mode…
98d46b1 sven1977: fixes
dd590fd simonsays1980: Merge branch 'master' into model-config-for-new-api-stack
46f5f5a simonsays1980: Fixed another bug in example file.
88bd055 simonsays1980: Merge branch 'model-config-for-new-api-stack' of github.com:simonsays…
c01438d simonsays1980: Merge branch 'model-config-for-new-api-stack' into rl-module-pre-trai…
891cbf2 sven1977: Merge branch 'master' into rl-module-pre-training-example
be72c17 sven1977: Apply suggestions from code review
25846c8 simonsays1980: Added review from @sven1977 and fixed a minor bug with the number of …
724b580 simonsays1980: Added example to the BUILD file.
5ccb496 simonsays1980: Changed paths in BUILD file.
abd565d simonsays1980: Added args to BUILD.
72f1be1 simonsays1980: Added more args and a comma.
e9b4af1 sven1977: Apply suggestions from code review
rllib/examples/rl_modules/pretraining_single_agent_training_multi_agent_rlm.py (149 additions, 0 deletions)
```python
"""Example of running single-agent pre-training followed by multi-agent training.

This example runs `num_agents` agents, each with its own `RLModule` that defines
its policy. The first agent is pre-trained using a single-agent PPO algorithm.
All agents are then trained together in the main training run, using a
multi-agent PPO algorithm in which the pre-trained module serves as the first
agent's policy.

The environment is MultiAgentCartPole with `num_agents` agents.

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --num-agents=2`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
"""
import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune import register_env

# Read in common example script command line arguments.
parser = add_rllib_example_script_args(
    # Use fewer training steps for the main training run ...
    default_timesteps=50000,
    default_reward=200.0,
    default_iters=20,
)
# ... and more for the pre-training run.
parser.add_argument(
    "--stop-iters-pretraining",
    type=int,
    default=200,
    help="The number of iterations to pre-train.",
)
parser.add_argument(
    "--stop-timesteps-pretraining",
    type=int,
    default=5000000,
    help="The number of (environment sampling) timesteps to pre-train.",
)


if __name__ == "__main__":
    # Parse the command line arguments.
    args = parser.parse_args()

    # Ensure that the user has set the number of agents.
    if args.num_agents == 0:
        raise ValueError(
            "This pre-training example script requires at least 1 agent. "
            "Try setting the command line argument `--num-agents` to the "
            "number of agents you want to use."
        )

    # Store the user's stopping criteria for the later training run.
    stop_iters = args.stop_iters
    stop_timesteps = args.stop_timesteps
    checkpoint_at_end = args.checkpoint_at_end
    num_agents = args.num_agents
    # Override these criteria for the pre-training run.
    setattr(args, "stop_iters", args.stop_iters_pretraining)
    setattr(args, "stop_timesteps", args.stop_timesteps_pretraining)
    setattr(args, "checkpoint_at_end", True)
    setattr(args, "num_agents", 0)

    # Define our pre-training single-agent algorithm. We use the same module
    # configuration for the pre-training and the main training run.
    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .rl_module(
            # Use a different number of hidden units for the pre-trained module.
            model_config_dict={"fcnet_hiddens": [64]},
        )
    )

    # Run the pre-training.
    results = run_rllib_example_script_experiment(config, args)
    # Get the checkpoint path.
    module_chkpt_path = results.get_best_result().checkpoint.path

    # Create a new MARL Module using the pre-trained module for policy 0.
    env = gym.make("CartPole-v1")
    module_specs = {}
    # Note: use the saved `num_agents` here; `args.num_agents` was set to 0
    # above for the single-agent pre-training run.
    for i in range(num_agents):
        module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config_dict={"fcnet_hiddens": [32]},
            catalog_class=PPOCatalog,
        )

    # Swap in the pre-trained module for policy 0.
    module_specs["policy_0"] = SingleAgentRLModuleSpec(
        module_class=PPOTorchRLModule,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"fcnet_hiddens": [64]},
        catalog_class=PPOCatalog,
        # Note, we load the module directly from the checkpoint here.
        load_state_path=module_chkpt_path,
    )
    marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)

    # Register our environment with tune if we use multiple agents.
    register_env(
        "multi-agent-cartpole-env",
        lambda _: MultiAgentCartPole(config={"num_agents": num_agents}),
    )

    # Configure the main (multi-agent) training run.
    config = (
        PPOConfig()
        .environment(
            "multi-agent-cartpole-env" if num_agents > 0 else "CartPole-v1"
        )
        .rl_module(rl_module_spec=marl_module_spec)
    )

    # Restore the user's stopping criteria for the training run.
    setattr(args, "stop_iters", stop_iters)
    setattr(args, "stop_timesteps", stop_timesteps)
    setattr(args, "checkpoint_at_end", checkpoint_at_end)
    setattr(args, "num_agents", num_agents)

    # Run the main training run.
    run_rllib_example_script_experiment(config, args)
```
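Stripped of RLlib specifics, the pattern the example file implements is: pre-train one policy, checkpoint its state, then seed exactly one slot of a multi-policy setup with that checkpoint while the remaining agents start fresh. The sketch below illustrates only this weight-transfer pattern with plain dicts; the `pretrain` helper is a made-up stand-in, not an RLlib API.

```python
import copy


def pretrain(policy, steps):
    # Stand-in for the single-agent pre-training run: just bump the weights.
    for _ in range(steps):
        policy["w"] += 1
    return policy


# "Pre-train" one policy and take a checkpoint of its state.
pretrained = pretrain({"w": 0}, steps=5)
checkpoint = copy.deepcopy(pretrained)

# Build the multi-agent setup: fresh modules for every agent ...
num_agents = 3
modules = {f"policy_{i}": {"w": 0} for i in range(num_agents)}
# ... then swap the checkpointed state in for policy_0 only.
modules["policy_0"] = copy.deepcopy(checkpoint)

print(modules["policy_0"])  # {'w': 5}  (carried over from pre-training)
print(modules["policy_1"])  # {'w': 0}  (fresh)
```

In the real script the dict values are `SingleAgentRLModuleSpec`s and the swap happens via `load_state_path`, but the shape of the logic is the same.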
Awesome!