Skip to content

Commit 2837d49

Browse files
Fix failing np tests (#3942)
* Fix failing np tests * Apply suggestions from code review * Update tests/pipelines/test_pipelines_common.py
1 parent 1997614 commit 2837d49

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,11 +698,13 @@ def _test_xformers_attention_forwardGenerator_pass(
698698
pipe.set_progress_bar_config(disable=None)
699699

700700
inputs = self.get_dummy_inputs(torch_device)
701-
output_without_offload = pipe(**inputs)[0].cpu()
701+
output_without_offload = pipe(**inputs)[0]
702+
output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
702703

703704
pipe.enable_xformers_memory_efficient_attention()
704705
inputs = self.get_dummy_inputs(torch_device)
705-
output_with_offload = pipe(**inputs)[0].cpu()
706+
output_with_offload = pipe(**inputs)[0]
707+
output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
706708

707709
if test_max_difference:
708710
max_diff = np.abs(output_with_offload - output_without_offload).max()

0 commit comments

Comments
 (0)