Skip to content

Commit 75cabb5

Browse files
committed
Update
1 parent 0b90f2b commit 75cabb5

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

xla/python/xla_extension/__init__.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -696,9 +696,16 @@ class DeviceTopology:
696696
def buffer_to_dlpack_managed_tensor(
697697
buffer: ArrayImpl, stream: int | None = None
698698
) -> Any: ...
699+
@overload
699700
def dlpack_managed_tensor_to_buffer(
700701
tensor: Any, device: Device, stream: int | None
701702
) -> ArrayImpl: ...
703+
@overload
704+
def dlpack_managed_tensor_to_buffer( # Legacy overload
705+
tensor: Any,
706+
cpu_backend: Optional[Client] = ...,
707+
gpu_backend: Optional[Client] = ...,
708+
) -> ArrayImpl: ...
702709

703710
def cuda_array_interface_to_buffer(
704711
cai: Dict[str, Union[
@@ -710,12 +717,6 @@ def cuda_array_interface_to_buffer(
710717
gpu_backend: Optional[Client] = ...,
711718
) -> ArrayImpl: ...
712719

713-
# Legacy overload
714-
def dlpack_managed_tensor_to_buffer(
715-
tensor: Any,
716-
cpu_backend: Optional[Client] = ...,
717-
gpu_backend: Optional[Client] = ...,
718-
) -> ArrayImpl: ...
719720

720721
# === BEGIN py_traceback.cc
721722

0 commit comments

Comments
 (0)