-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathjepa.py
More file actions
154 lines (119 loc) · 4.99 KB
/
jepa.py
File metadata and controls
154 lines (119 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""JEPA Implementation"""
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
def detach_clone(v):
    """Return a detached copy of *v* when it is a tensor; pass anything else through unchanged."""
    if torch.is_tensor(v):
        return v.detach().clone()
    return v
class JEPA(nn.Module):
    """Joint-Embedding Predictive Architecture.

    Wraps an observation ``encoder``, a latent-dynamics ``predictor`` and an
    ``action_encoder``; optional ``projector`` / ``pred_proj`` heads default
    to identity mappings. Provides encoding/prediction plus inference-only
    rollout and planning-cost utilities.
    """

    def __init__(
        self,
        encoder,
        predictor,
        action_encoder,
        projector=None,
        pred_proj=None,
    ):
        """
        Args:
            encoder: vision backbone; must accept ``interpolate_pos_encoding``
                and return an object exposing ``last_hidden_state``
                (ViT-style output — see ``encode``).
            predictor: maps (state embeddings, action embeddings) to
                predicted next-state embeddings.
            action_encoder: embeds raw action tensors.
            projector: optional head applied to the encoder's CLS embedding;
                defaults to ``nn.Identity()``.
            pred_proj: optional head applied to predictor outputs; defaults
                to ``nn.Identity()``.
        """
        super().__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.action_encoder = action_encoder
        # NOTE: `or nn.Identity()` replaces any falsy value, not just None.
        self.projector = projector or nn.Identity()
        self.pred_proj = pred_proj or nn.Identity()

    def encode(self, info):
        """Encode observations and actions into embeddings.
        info: dict with pixels and action keys

        Mutates and returns ``info``: adds ``"emb"`` of shape (B, T, D) and,
        when an ``"action"`` key is present, ``"act_emb"``.
        Assumes ``info["pixels"]`` is (B, T, ...) image frames — confirm
        against callers.
        """
        pixels = info['pixels'].float()
        b = pixels.size(0)
        pixels = rearrange(pixels, "b t ... -> (b t) ...")  # flatten for encoding
        output = self.encoder(pixels, interpolate_pos_encoding=True)
        pixels_emb = output.last_hidden_state[:, 0]  # cls token
        emb = self.projector(pixels_emb)
        # restore the (batch, time) layout after per-frame encoding
        info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
        if "action" in info:
            info["act_emb"] = self.action_encoder(info["action"])
        return info

    def predict(self, emb, act_emb):
        """Predict next state embedding
        emb: (B, T, D)
        act_emb: (B, T, A_emb)

        Returns a (B, T, D') tensor, D' determined by ``pred_proj``.
        """
        preds = self.predictor(emb, act_emb)
        # pred_proj is applied per timestep, so flatten (B, T) first
        preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
        preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
        return preds

    ####################
    ## Inference only ##
    ####################
    def rollout(self, info, action_sequence, history_size: int = 3):
        """Rollout the model given an initial info dict and action sequence.
        pixels: (B, S, T, C, H, W)
        action_sequence: (B, S, T, action_dim)
        - S is the number of action plan samples
        - T is the time horizon

        history_size: number of trailing (state, action) steps fed to the
        predictor at each autoregressive step.

        Sets ``info["predicted_emb"]`` of shape (B, S, T_out, D).
        NOTE(review): the loop runs T - H steps and one extra prediction is
        appended after it, so the rollout contains H + (T - H) + 1 steps —
        confirm the trailing prediction is intentional.
        """
        assert "pixels" in info, "pixels not in info_dict"
        # H = number of observed context frames, taken from the pixel history
        H = info["pixels"].size(2)
        B, S, T = action_sequence.shape[:3]
        # split actions into the observed context (act_0) and the future plan
        act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
        info["action"] = act_0
        n_steps = T - H
        # copy and encode initial info dict
        # only the first plan sample (index 0) is encoded; its embedding is
        # then shared across all S samples via expand below
        _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
        _init = self.encode(_init)
        emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
        _init = {k: detach_clone(v) for k, v in _init.items()}
        # flatten batch and sample dimensions for rollout
        # .clone() materializes the expanded view so cat/indexing are safe
        emb = rearrange(emb, "b s ... -> (b s) ...").clone()
        act = rearrange(act_0, "b s ... -> (b s) ...")
        act_future = rearrange(act_future, "b s ... -> (b s) ...")
        # rollout predictor autoregressively for n_steps
        HS = history_size
        for t in range(n_steps):
            # re-embed the full action history each step, then truncate to
            # the last HS steps to match the truncated state history
            act_emb = self.action_encoder(act)
            emb_trunc = emb[:, -HS:]
            act_trunc = act_emb[:, -HS:]
            # keep only the newest prediction and append it to the history
            pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:]
            emb = torch.cat([emb, pred_emb], dim=1)
            next_act = act_future[:, t : t + 1, :]
            act = torch.cat([act, next_act], dim=1)
        # predict the last state
        act_emb = self.action_encoder(act)
        emb_trunc = emb[:, -HS:]
        act_trunc = act_emb[:, -HS:]
        pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:]
        emb = torch.cat([emb, pred_emb], dim=1)
        # unflatten batch and sample dimensions
        pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
        info["predicted_emb"] = pred_rollout
        return info

    def criterion(self, info_dict: dict):
        """Compute the cost between predicted embeddings and goal embeddings.

        Compares only the final rollout timestep against the final goal
        embedding; squared error is summed over all dims from 2 onward,
        yielding a (B, S) cost tensor.
        NOTE(review): ``expand_as`` aligns trailing dims, so ``goal_emb``
        must broadcast against ``predicted_emb``'s (B, S, T, D) — confirm
        shapes upstream.
        """
        pred_emb = info_dict["predicted_emb"]
        goal_emb = info_dict["goal_emb"]
        goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
        # the [..., -1:, :] slice on goal_emb below is a no-op after the
        # expand above (every timestep already holds the final goal)
        cost = F.mse_loss(
            pred_emb[..., -1:, :],
            goal_emb[..., -1:, :].detach(),
            reduction="none",
        ).sum(dim=tuple(range(2, pred_emb.ndim)))
        return cost

    def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
        """Compute the cost of action candidates given an info dict with goal and initial state.

        Moves all tensors in ``info_dict`` to the model's device, builds a
        goal observation dict (renaming ``"goal"``/``"goal_*"`` keys to the
        names ``encode`` expects), encodes it, rolls out the candidates and
        scores them with ``criterion``. Returns the (B, S) cost tensor.
        """
        assert "goal" in info_dict, "goal not in info_dict"
        device = next(self.parameters()).device
        # move tensors onto the model's device, mutating info_dict in place
        for k in list(info_dict.keys()):
            if torch.is_tensor(info_dict[k]):
                info_dict[k] = info_dict[k].to(device)
        # build the goal frame dict from the first sample of each tensor
        goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
        goal["pixels"] = goal["goal"]  # encode the goal image as the observation
        # strip the "goal_" prefix so encode() sees the expected key names
        for k in info_dict:
            if k.startswith("goal_"):
                goal[k[len("goal_") :]] = goal.pop(k)
        # drop actions so encode() does not embed them for the goal frame
        # NOTE(review): raises KeyError if "action" is absent — confirm callers
        goal.pop("action")
        goal = self.encode(goal)
        info_dict["goal_emb"] = goal["emb"]
        info_dict = self.rollout(info_dict, action_candidates)
        cost = self.criterion(info_dict)
        return cost