Merged
Commits
41 commits
5a85f24
Defined new properties '_model_auto_keys' and 'model_config_dict' to …
simonsays1980 Mar 25, 2024
0c2dbc9
Added new properties to 'PPO' and 'BC' in new stack and modified tune…
simonsays1980 Mar 25, 2024
aa84898
Added new property to SAC in new stack and modified tuned example acc…
simonsays1980 Mar 25, 2024
56f4cf6
Fixed multiple tests and examples to incorporate the new 'model_confi…
simonsays1980 Mar 26, 2024
9e50e87
Reran multiple tests as some were failing in CI. Error in VisionNet r…
simonsays1980 Mar 26, 2024
70067f9
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Mar 26, 2024
0e56410
Changed example settings as many example runs were failing in CI tests.
simonsays1980 Mar 27, 2024
904ca01
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Mar 27, 2024
f50c0f5
Included @sven1977's review.
simonsays1980 Apr 2, 2024
2055087
Quick fix as tests were failing because hard deprecation of '_enable_…
simonsays1980 Apr 2, 2024
8b7fbfd
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Apr 8, 2024
61bb265
Adapted some connector examples to using the new model config.
simonsays1980 Apr 8, 2024
32cf531
Fixed a minor bug in PPO model config.
simonsays1980 Apr 8, 2024
3f365e0
Removed bug in SingleAgentEnvRunner using the old model_config_dict i…
simonsays1980 Apr 8, 2024
0d1635e
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Apr 9, 2024
84269c4
Adjusted a couple of examples to use the 'model_config_dict' in 'rl_m…
simonsays1980 Apr 9, 2024
7d92725
Fixed some minor bugs in CI tests.
simonsays1980 Apr 9, 2024
ff8c29a
Fixed bug in docs.
simonsays1980 Apr 9, 2024
6f1252f
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Apr 10, 2024
74017bd
Merging master and linting.
simonsays1980 Apr 10, 2024
9fc4f86
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Apr 11, 2024
5727729
Fixed remaining bugs in test and docs caused by external RLModule's n…
simonsays1980 Apr 11, 2024
b2f92f2
Moved example to another branch.
simonsays1980 Apr 11, 2024
21cea3b
Added a pre-training example for RLModules in a MARL setting. Pre-tra…
simonsays1980 Apr 11, 2024
9f595f7
Fixed failing test in 'rllib/tests'. The test was failing because the…
simonsays1980 Apr 11, 2024
f5d37c2
Fixed failing test code in docs.
simonsays1980 Apr 11, 2024
57db62b
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Apr 11, 2024
f465a3b
Merge branch 'master' of https://github.com/ray-project/ray into mode…
sven1977 Apr 12, 2024
98d46b1
fixes
sven1977 Apr 12, 2024
dd590fd
Merge branch 'master' into model-config-for-new-api-stack
simonsays1980 Apr 12, 2024
46f5f5a
Fixed another bug in example file.
simonsays1980 Apr 12, 2024
88bd055
Merge branch 'model-config-for-new-api-stack' of github.com:simonsays…
simonsays1980 Apr 12, 2024
c01438d
Merge branch 'model-config-for-new-api-stack' into rl-module-pre-trai…
simonsays1980 Apr 12, 2024
891cbf2
Merge branch 'master' into rl-module-pre-training-example
sven1977 Apr 15, 2024
be72c17
Apply suggestions from code review
sven1977 Apr 15, 2024
25846c8
Added review from @sven1977 and fixed a minor bug with the number of …
simonsays1980 Apr 15, 2024
724b580
Added example to the BUILD file.
simonsays1980 Apr 15, 2024
5ccb496
Changed paths in BUILD file.
simonsays1980 Apr 15, 2024
abd565d
Added args to BUILD.
simonsays1980 Apr 16, 2024
72f1be1
Added more args and a comma.
simonsays1980 Apr 16, 2024
e9b4af1
Apply suggestions from code review
sven1977 Apr 16, 2024
9 changes: 8 additions & 1 deletion rllib/BUILD
@@ -2873,7 +2873,14 @@ py_test(
size = "small",
srcs = ["examples/rl_modules/classes/mobilenet_rlm.py"],
)

py_test(
Contributor comment: Awesome!

name = "examples/rl_modules/pretraining_single_agent_training_multi_agent_rlm",
main = "examples/rl_modules/pretraining_single_agent_training_multi_agent_rlm.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/pretraining_single_agent_training_multi_agent_rlm.py"],
args = ["--enable-new-api-stack", "--num-agents=2", "--stop-iters-pretraining=5", "--stop-iters=20", "--stop-reward=150.0"],
)



examples/rl_modules/pretraining_single_agent_training_multi_agent_rlm.py
@@ -0,0 +1,149 @@
"""Example of running a single-agent pre-training followed with a multi-agent training.

This example runs `num_agents` agents, each with its own `RLModule` that defines its
policy. The first agent is pre-trained with a single-agent PPO algorithm. All agents
are then trained together in the main training run with a multi-agent PPO algorithm,
in which the pre-trained module is used for the first agent.

The environment is MultiAgentCartPole with `num_agents` agents, each acting under
its own policy.

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --num-agents=2`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
"""

import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,
)
from ray.tune import register_env

# Read in common example script command line arguments.
parser = add_rllib_example_script_args(
# Use fewer training steps for the main training run.
default_timesteps=50000,
default_reward=200.0,
default_iters=20,
)
# Instead, use more for the pre-training run.
parser.add_argument(
"--stop-iters-pretraining",
type=int,
default=200,
help="The number of iterations to pre-train.",
)
parser.add_argument(
"--stop-timesteps-pretraining",
type=int,
default=5000000,
help="The number of (environment sampling) timesteps to pre-train.",
)


if __name__ == "__main__":

# Parse the command line arguments.
args = parser.parse_args()

# Ensure that the user has set the number of agents.
if args.num_agents == 0:
raise ValueError(
"This pre-training example script requires at least 1 agent. "
"Try setting the command line argument `--num-agents` to the "
"number of agents you want to use."
)

# Store the user's stopping criteria for the later training run.
stop_iters = args.stop_iters
stop_timesteps = args.stop_timesteps
checkpoint_at_end = args.checkpoint_at_end
num_agents = args.num_agents
# Override these criteria for the pre-training run.
setattr(args, "stop_iters", args.stop_iters_pretraining)
setattr(args, "stop_timesteps", args.stop_timesteps_pretraining)
setattr(args, "checkpoint_at_end", True)
setattr(args, "num_agents", 0)

# Define our pre-training single-agent algorithm. We use the same module
# configuration for pre-training and for the pre-trained agent in training.
config = (
PPOConfig()
.environment("CartPole-v1")
.rl_module(
# Use a different number of hidden units for the pre-trained module.
model_config_dict={"fcnet_hiddens": [64]},
)
)

# Run the pre-training.
results = run_rllib_example_script_experiment(config, args)
# Get the checkpoint path.
module_chkpt_path = results.get_best_result().checkpoint.path

# Create a new MARL Module using the pre-trained module for policy 0.
env = gym.make("CartPole-v1")
module_specs = {}
# Note: use the saved `num_agents` here. `args.num_agents` was overridden to 0
# for the pre-training run and is only restored further below.
for i in range(num_agents):
module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
module_class=PPOTorchRLModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [32]},
catalog_class=PPOCatalog,
)

# Swap in the pre-trained module for policy 0.
module_specs["policy_0"] = SingleAgentRLModuleSpec(
module_class=PPOTorchRLModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
# Note: load the module's state directly from the pre-training checkpoint.
load_state_path=module_chkpt_path,
)
marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)

# Register our environment with Tune if we use multiple agents.
register_env(
"multi-agent-cartpole-env",
lambda _: MultiAgentCartPole(config={"num_agents": num_agents}),
)

# Configure the main (multi-agent) training run. Note: use the saved
# `num_agents` here as well, since `args.num_agents` is still 0 at this point.
config = (
PPOConfig()
.environment(
"multi-agent-cartpole-env" if num_agents > 0 else "CartPole-v1"
)
.rl_module(rl_module_spec=marl_module_spec)
)

# Restore the user's stopping criteria for the training run.
setattr(args, "stop_iters", stop_iters)
setattr(args, "stop_timesteps", stop_timesteps)
setattr(args, "checkpoint_at_end", checkpoint_at_end)
setattr(args, "num_agents", num_agents)

# Run the main training run.
run_rllib_example_script_experiment(config, args)
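As a side note, the save/override/restore juggling of the stopping criteria in the script above can be sketched in isolation with a plain `argparse.Namespace`. The attribute names and values below are illustrative stand-ins, not RLlib's actual defaults:

```python
import argparse

# Illustrative stand-in for the script's parsed CLI args.
args = argparse.Namespace(stop_iters=20, stop_timesteps=50_000, num_agents=2)

# 1) Save the user's settings for the later (multi-agent) training run.
saved = {k: getattr(args, k) for k in ("stop_iters", "stop_timesteps", "num_agents")}

# 2) Override them for the (single-agent) pre-training run.
args.stop_iters = 200
args.stop_timesteps = 5_000_000
args.num_agents = 0

# ... the pre-training run would consume `args` here ...

# 3) Restore the saved settings before configuring the main run.
for k, v in saved.items():
    setattr(args, k, v)
```

Restoring the saved values *before* building the multi-agent config (rather than after) avoids the pitfall of `args.num_agents` still being 0 when the environment choice is made.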