Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
T5Config,
)
from transformers.testing_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
is_torch_available,
require_flash_attn,
require_torch,
Expand Down Expand Up @@ -1095,10 +1097,14 @@ def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
valid_compute_capability = False
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
valid_compute_capability = major >= 8

if not valid_compute_capability:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")

for model_class in self.all_model_classes:
Expand Down
12 changes: 9 additions & 3 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
T5Config,
)
from transformers.testing_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
is_torch_available,
is_torchaudio_available,
require_flash_attn,
Expand Down Expand Up @@ -1085,10 +1087,14 @@ def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
valid_compute_capability = False
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
valid_compute_capability = major >= 8

if not valid_compute_capability:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")

for model_class in self.all_model_classes:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
MODEL_MAPPING_NAMES,
)
from transformers.testing_utils import (
CaptureLogger,
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
hub_retry,
is_flaky,
Expand Down Expand Up @@ -3766,10 +3768,14 @@ def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability

if not torch.version.cuda or major < 8:
valid_compute_capability = False
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
valid_compute_capability = major >= 8

if not valid_compute_capability:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")

for model_class in self.all_model_classes:
Expand Down Expand Up @@ -3810,13 +3816,17 @@ def test_sdpa_can_dispatch_on_flash(self):
def test_sdpa_can_compile_dynamic(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

torch.compiler.reset()
if "cuda" in torch_device:

valid_compute_capability = False
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
valid_compute_capability = major >= 8

if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
if not valid_compute_capability:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")

for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
Expand Down