@@ -183,12 +183,12 @@ def test_forward_backward(self, model_name):
183
183
out_agg = 0
184
184
for node_out in out .values ():
185
185
if isinstance (node_out , Sequence ):
186
- out_agg += sum (o .mean () for o in node_out if o is not None )
186
+ out_agg += sum (o .float (). mean () for o in node_out if o is not None )
187
187
elif isinstance (node_out , Mapping ):
188
- out_agg += sum (o .mean () for o in node_out .values () if o is not None )
188
+ out_agg += sum (o .float (). mean () for o in node_out .values () if o is not None )
189
189
else :
190
190
# Assume that the only other alternative at this point is a Tensor
191
- out_agg += node_out .mean ()
191
+ out_agg += node_out .float (). mean ()
192
192
out_agg .backward ()
193
193
194
194
def test_feature_extraction_methods_equivalence (self ):
@@ -224,12 +224,12 @@ def test_jit_forward_backward(self, model_name):
224
224
out_agg = 0
225
225
for node_out in fgn_out .values ():
226
226
if isinstance (node_out , Sequence ):
227
- out_agg += sum (o .mean () for o in node_out if o is not None )
227
+ out_agg += sum (o .float (). mean () for o in node_out if o is not None )
228
228
elif isinstance (node_out , Mapping ):
229
- out_agg += sum (o .mean () for o in node_out .values () if o is not None )
229
+ out_agg += sum (o .float (). mean () for o in node_out .values () if o is not None )
230
230
else :
231
231
# Assume that the only other alternative at this point is a Tensor
232
- out_agg += node_out .mean ()
232
+ out_agg += node_out .float (). mean ()
233
233
out_agg .backward ()
234
234
235
235
def test_train_eval (self ):
@@ -239,7 +239,7 @@ def __init__(self):
239
239
self .dropout = torch .nn .Dropout (p = 1.0 )
240
240
241
241
def forward (self , x ):
242
- x = x .mean ()
242
+ x = x .float (). mean ()
243
243
x = self .dropout (x ) # dropout
244
244
if self .training :
245
245
x += 100 # add
@@ -330,4 +330,4 @@ def forward(self, x):
330
330
# Check forward
331
331
out = model (self .inp )
332
332
# And backward
333
- out ["leaf_module" ].mean ().backward ()
333
+ out ["leaf_module" ].float (). mean ().backward ()
0 commit comments