File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -179,6 +179,8 @@ def get_peak_memory() -> tuple[int, int]:
179179 return torch .npu .max_memory_allocated (), torch .npu .max_memory_reserved ()
180180 elif is_torch_cuda_available ():
181181 return torch .cuda .max_memory_allocated (), torch .cuda .max_memory_reserved ()
182+ elif is_torch_xpu_available ():
183+ return torch .xpu .max_memory_allocated (), torch .xpu .max_memory_reserved ()
182184 else :
183185 return 0 , 0
184186
@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
200202
201203def is_gpu_or_npu_available () -> bool :
202204 r"""Check if the GPU or NPU is available."""
203- return is_torch_npu_available () or is_torch_cuda_available ()
205+ return is_torch_npu_available () or is_torch_cuda_available () or is_torch_xpu_available ()
204206
205207
206208def is_env_enabled (env_var : str , default : str = "0" ) -> bool :
You can’t perform that action at this time.
0 commit comments