File tree 1 file changed +4
-2
lines changed
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -698,11 +698,13 @@ def _test_xformers_attention_forwardGenerator_pass(
698
698
pipe .set_progress_bar_config (disable = None )
699
699
700
700
inputs = self .get_dummy_inputs (torch_device )
701
- output_without_offload = pipe (** inputs )[0 ].cpu ()
701
+ output_without_offload = pipe (** inputs )[0 ]
702
+ output_without_offload .cpu () if torch .is_tensor (output_without_offload ) else output_without_offload
702
703
703
704
pipe .enable_xformers_memory_efficient_attention ()
704
705
inputs = self .get_dummy_inputs (torch_device )
705
- output_with_offload = pipe (** inputs )[0 ].cpu ()
706
+ output_with_offload = pipe (** inputs )[0 ]
707
+ output_with_offload .cpu () if torch .is_tensor (output_with_offload ) else output_without_offload
706
708
707
709
if test_max_difference :
708
710
max_diff = np .abs (output_with_offload - output_without_offload ).max ()
You can’t perform that action at this time.
0 commit comments