Commit 7cf9d93 (1 parent: f02923a)
src/transformers/modeling_utils.py
@@ -1698,6 +1698,10 @@ def _check_and_enable_flash_attn_2(
                 raise ImportError(
                     f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
                 )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
+                )
             else:
                 raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
         elif torch.version.hip:
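For context, a minimal sketch of what this change means for callers: requesting Flash Attention 2 on a machine where torch cannot see a CUDA device now fails fast with the explicit ValueError added above, rather than falling through to a less informative error. The checkpoint name below is a placeholder and the loading keyword depends on the transformers version (recent releases use `attn_implementation`; older ones used `use_flash_attention_2=True`); neither is part of this commit.

# Illustrative sketch (not part of the commit): how the new check surfaces to users.
# Assumes flash_attn is installed but torch has no CUDA device available.
import torch
from transformers import AutoModelForCausalLM

try:
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",  # kwarg name varies by transformers version
    )
except (ImportError, ValueError) as err:
    # On a CPU-only setup this now surfaces the explicit
    # "Flash Attention 2 is not available on CPU" ValueError added in this commit.
    print(err)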