Skip to content

Commit 8e84383

Browse files
authored
[misc] fix cuda warn on intel GPU (hiyouga#7655)
1 parent dfc63d8 commit 8e84383

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/llamafactory/extras/misc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
def get_peak_memory() -> tuple[int, int]:
    r"""Return the peak device memory usage as ``(allocated, reserved)`` in bytes.

    Probes the available accelerator backends in priority order
    (NPU, then CUDA, then Intel XPU) and reads the peak counters from the
    matching ``torch`` backend module. Falls back to ``(0, 0)`` when no
    supported accelerator is present.
    """
    # Bind the matching backend module once; all three expose the same
    # max_memory_allocated / max_memory_reserved API.
    if is_torch_npu_available():
        backend = torch.npu
    elif is_torch_cuda_available():
        backend = torch.cuda
    elif is_torch_xpu_available():
        backend = torch.xpu
    else:
        return 0, 0

    return backend.max_memory_allocated(), backend.max_memory_reserved()
184186

@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
200202

201203
def is_gpu_or_npu_available() -> bool:
    r"""Check if the GPU or NPU is available."""
    # Generator keeps the original short-circuit semantics: each backend
    # probe (NPU, CUDA, Intel XPU) runs only if the previous ones were False.
    accelerator_checks = (is_torch_npu_available, is_torch_cuda_available, is_torch_xpu_available)
    return any(check() for check in accelerator_checks)
204206

205207

206208
def is_env_enabled(env_var: str, default: str = "0") -> bool:

0 commit comments

Comments
 (0)