Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
T5Config,
)
from transformers.testing_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
is_torch_available,
require_flash_attn,
require_torch,
Expand Down Expand Up @@ -1094,11 +1096,15 @@ def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if IS_CUDA_SYSTEM and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif IS_ROCM_SYSTEM and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
                self.skipTest(reason="This test requires an NVIDIA or AMD GPU")

for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
Expand Down
16 changes: 11 additions & 5 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
T5Config,
)
from transformers.testing_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
is_torch_available,
is_torchaudio_available,
require_flash_attn,
Expand Down Expand Up @@ -1084,11 +1086,15 @@ def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if IS_CUDA_SYSTEM and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif IS_ROCM_SYSTEM and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
                self.skipTest(reason="This test requires an NVIDIA or AMD GPU")

for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
Expand Down
26 changes: 18 additions & 8 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
MODEL_MAPPING_NAMES,
)
from transformers.testing_utils import (
    CaptureLogger,
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
hub_retry,
is_flaky,
Expand Down Expand Up @@ -3764,11 +3766,15 @@ def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if IS_CUDA_SYSTEM and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif IS_ROCM_SYSTEM and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
                self.skipTest(reason="This test requires an NVIDIA or AMD GPU")

for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
Expand Down Expand Up @@ -3808,13 +3814,17 @@ def test_sdpa_can_dispatch_on_flash(self):
def test_sdpa_can_compile_dynamic(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
if "cuda" in torch_device:
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
if IS_CUDA_SYSTEM and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif IS_ROCM_SYSTEM and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
                self.skipTest(reason="This test requires an NVIDIA or AMD GPU")

for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
Expand Down