Skip to content

Commit 7df2d08

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] Add .float() before .mean() on test_backbone_utils.py because .mean() dont accept integer dtype (#6090)
Reviewed By: NicolasHug Differential Revision: D36760937 fbshipit-source-id: dfe3ef93953f9f7d4e55cefa40f177e574a5271f
1 parent d28ec48 commit 7df2d08

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

test/test_backbone_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ def test_forward_backward(self, model_name):
183183
out_agg = 0
184184
for node_out in out.values():
185185
if isinstance(node_out, Sequence):
186-
out_agg += sum(o.mean() for o in node_out if o is not None)
186+
out_agg += sum(o.float().mean() for o in node_out if o is not None)
187187
elif isinstance(node_out, Mapping):
188-
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
188+
out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
189189
else:
190190
# Assume that the only other alternative at this point is a Tensor
191-
out_agg += node_out.mean()
191+
out_agg += node_out.float().mean()
192192
out_agg.backward()
193193

194194
def test_feature_extraction_methods_equivalence(self):
@@ -224,12 +224,12 @@ def test_jit_forward_backward(self, model_name):
224224
out_agg = 0
225225
for node_out in fgn_out.values():
226226
if isinstance(node_out, Sequence):
227-
out_agg += sum(o.mean() for o in node_out if o is not None)
227+
out_agg += sum(o.float().mean() for o in node_out if o is not None)
228228
elif isinstance(node_out, Mapping):
229-
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
229+
out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
230230
else:
231231
# Assume that the only other alternative at this point is a Tensor
232-
out_agg += node_out.mean()
232+
out_agg += node_out.float().mean()
233233
out_agg.backward()
234234

235235
def test_train_eval(self):
@@ -239,7 +239,7 @@ def __init__(self):
239239
self.dropout = torch.nn.Dropout(p=1.0)
240240

241241
def forward(self, x):
242-
x = x.mean()
242+
x = x.float().mean()
243243
x = self.dropout(x) # dropout
244244
if self.training:
245245
x += 100 # add
@@ -330,4 +330,4 @@ def forward(self, x):
330330
# Check forward
331331
out = model(self.inp)
332332
# And backward
333-
out["leaf_module"].mean().backward()
333+
out["leaf_module"].float().mean().backward()

0 commit comments

Comments
 (0)