Commit 8c18695

Behavior Cloning
1 parent 344fff2 commit 8c18695

6 files changed: +87 -160 lines changed

python/ray/rllib/bc/bc_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@

 import ray
 from ray.rllib.bc.experience_dataset import ExperienceDataset
-from ray.rllib.bc.policy import Policy
+from ray.rllib.bc.policy import BCPolicy
 from ray.rllib.models import ModelCatalog
 from ray.rllib.optimizers import Evaluator

@@ -17,7 +17,7 @@ def __init__(self, registry, env_creator, config, logdir):
         env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator(config["env_config"]), config["model"])
         self.dataset = ExperienceDataset(config["dataset_path"])
         # TODO(rliaw): should change this to be just env.observation_space
-        self.policy = Policy(registry, env.observation_space.shape, env.action_space, config)
+        self.policy = BCPolicy(registry, env.observation_space.shape, env.action_space, config)
         self.config = config
         self.logdir = logdir
         self.metrics_queue = queue.Queue()

python/ray/rllib/bc/experience_dataset.py

Lines changed: 10 additions & 0 deletions
@@ -10,6 +10,16 @@

 class ExperienceDataset(object):
     def __init__(self, dataset_path):
+        """Create dataset of experience to imitate.
+
+        Parameters
+        ----------
+        dataset_path:
+            Path of file containing the database as pickled list of trajectories,
+            each trajectory being a list of steps,
+            each step containing the observation and action as its first two elements.
+            The file must be available on each machine used by a BCEvaluator.
+        """
         self._dataset = list(itertools.chain.from_iterable(pickle.load(open(dataset_path, "rb"))))

     def sample(self, batch_size):
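
For context, a minimal sketch of a dataset file in the layout this docstring describes (not part of the diff; the path and values are purely illustrative). The updated rollout.py below writes the same structure when --out is given.

    import pickle

    # Two illustrative trajectories; each step lists the observation and the
    # action as its first two elements. Extra per-step fields (next_state,
    # reward, done) are fine, since only the first two are documented.
    trajectories = [
        [[[0.1, 0.2], 0], [[0.3, 0.4], 1]],   # trajectory 1: two steps
        [[[0.5, 0.6], 1]],                    # trajectory 2: one step
    ]
    with open("/tmp/bc_demo.pkl", "wb") as f:
        pickle.dump(trajectories, f)

    # ExperienceDataset("/tmp/bc_demo.pkl") flattens this into a single list
    # of three steps via itertools.chain.from_iterable.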

python/ray/rllib/bc/policy.py

Lines changed: 54 additions & 2 deletions
@@ -2,12 +2,29 @@
 from __future__ import division
 from __future__ import print_function

+import ray
 import tensorflow as tf
-from ray.rllib.bc.tfpolicy import TFPolicy
+from ray.rllib.a3c.policy import Policy
 from ray.rllib.models.catalog import ModelCatalog


-class Policy(TFPolicy):
+class BCPolicy(Policy):
+    def __init__(self, registry, ob_space, action_space, config, name="local", summarize=True):
+        super(BCPolicy, self).__init__(ob_space, action_space, name, summarize)
+        self.registry = registry
+        self.local_steps = 0
+        self.config = config
+        self.summarize = summarize
+        worker_device = "/job:localhost/replica:0/task:0/cpu:0"
+        self.g = tf.Graph()
+        with self.g.as_default(), tf.device(worker_device):
+            with tf.variable_scope(name):
+                self._setup_graph(ob_space, action_space)
+            print("Setting up loss")
+            self.setup_loss(action_space)
+            self.setup_gradients()
+            self.initialize()
+
     def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
         dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)

@@ -25,6 +42,29 @@ def setup_loss(self, action_space):
         self.pi_loss = - tf.reduce_sum(log_prob)
         self.loss = self.pi_loss

+    def setup_gradients(self):
+        grads = tf.gradients(self.loss, self.var_list)
+        self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
+        grads_and_vars = list(zip(self.grads, self.var_list))
+        opt = tf.train.AdamOptimizer(self.config["lr"])
+        self._apply_gradients = opt.apply_gradients(grads_and_vars)
+
+    def initialize(self):
+        if self.summarize:
+            bs = tf.to_float(tf.shape(self.x)[0])
+            tf.summary.scalar("model/policy_loss", self.pi_loss / bs)
+            tf.summary.scalar("model/grad_gnorm", tf.global_norm(self.grads))
+            tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
+            self.summary_op = tf.summary.merge_all()
+
+        # TODO(rliaw): Can consider exposing these parameters
+        self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
+            intra_op_parallelism_threads=1, inter_op_parallelism_threads=2,
+            gpu_options=tf.GPUOptions(allow_growth=True)))
+        self.variables = ray.experimental.TensorFlowVariables(self.loss,
+                                                              self.sess)
+        self.sess.run(tf.global_variables_initializer())
+
     def compute_gradients(self, samples):
         info = {}
         feed_dict = {

@@ -42,6 +82,18 @@ def compute_gradients(self, samples):
         info["loss"] = loss
         return grad, info

+    def apply_gradients(self, grads):
+        feed_dict = {self.grads[i]: grads[i]
+                     for i in range(len(grads))}
+        self.sess.run(self._apply_gradients, feed_dict=feed_dict)
+
+    def get_weights(self):
+        weights = self.variables.get_weights()
+        return weights
+
+    def set_weights(self, weights):
+        self.variables.set_weights(weights)
+
     def compute(self, ob, *args):
         action = self.sess.run(self.sample, {self.x: [ob]})
         return action, None
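
For context, a rough sketch (not part of the diff) of how an optimizer or evaluator might drive the methods added above; `policy` is a constructed BCPolicy, `dataset` an ExperienceDataset, and the config passed to BCPolicy is assumed to contain the "lr" and "grad_clip" keys read by setup_gradients().

    def local_bc_update(policy, dataset, batch_size):
        # One hypothetical behavior-cloning step: sample demonstration steps,
        # compute clipped gradients in the policy's own session, then apply
        # them with the Adam op built in setup_gradients().
        samples = dataset.sample(batch_size)
        grad, info = policy.compute_gradients(samples)
        policy.apply_gradients(grad)
        return info

    # Across processes, parameters travel as flat weights:
    #     weights = local_policy.get_weights()
    #     remote_policy.set_weights(weights)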

python/ray/rllib/bc/tfpolicy.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

python/ray/rllib/eval.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

python/ray/rllib/rollout.py

Lines changed: 21 additions & 11 deletions
@@ -35,10 +35,13 @@
     "tune registry.")
 required_named.add_argument(
     "--env", type=str, help="The gym environment to use.")
-required_named.add_argument(
-    "--steps", type=str, help="Number of steps to roll out.")
-required_named.add_argument(
-    "--out", type=str, help="Output filename.")
+parser.add_argument(
+    "--no-render", default=False, action="store_const", const=True,
+    help="Surpress rendering of the environment.")
+parser.add_argument(
+    "--steps", default=None, help="Number of steps to roll out.")
+parser.add_argument(
+    "--out", default=None, help="Output filename.")
 parser.add_argument(
     "--config", default="{}", type=json.loads,
     help="Algorithm-specific configuration (e.g. env, hyperparams), ")

@@ -59,16 +62,23 @@
     num_steps = int(args.steps)

     env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(), gym.make(args.env))
-    rollouts = []
+    if args.out is not None:
+        rollouts = []
     steps = 0
-    while steps < num_steps:
-        rollout = []
+    while steps < (num_steps or steps + 1):
+        if args.out is not None:
+            rollout = []
         state = env.reset()
         done = False
-        while not done and steps < num_steps:
+        while not done and steps < (num_steps or steps + 1):
             action = agent.compute_action(state)
             next_state, reward, done, _ = env.step(action)
-            rollout.append([state, action, next_state, reward, done])
+            if not args.no_render:
+                env.render()
+            if args.out is not None:
+                rollout.append([state, action, next_state, reward, done])
             steps += 1
-        rollouts.append(rollout)
-    pickle.dump(rollouts, open(args.out, "wb"))
+        if args.out is not None:
+            rollouts.append(rollout)
+    if args.out is not None:
+        pickle.dump(rollouts, open(args.out, "wb"))
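
One non-obvious part of the new loop is the `steps < (num_steps or steps + 1)` guard, which makes --steps optional: when no step limit is set (num_steps falsy), the comparison reduces to `steps < steps + 1`, which is always true, so the inner loop only stops when the episode terminates; a positive num_steps still caps the total step count. A small standalone sketch of the inner-loop behavior (illustrative numbers only, not part of the diff):

    for num_steps in (None, 3):
        steps, done = 0, False
        while not done and steps < (num_steps or steps + 1):
            steps += 1
            done = steps >= 5              # pretend the episode ends after 5 steps
        print(num_steps, "->", steps)      # None -> 5 (full episode), 3 -> 3 (capped)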
