Skip to content

Commit 46334b5

Browse files
authored
xp2{tensorflow,torch}: convert NumPy arrays using dlpack (#686)
* xp2{tensorflow,torch}: convert NumPy arrays using dlpack Newer versions of NumPy can expose arrays as dlpack capsules. Use this functionality (when supported) to speed up NumPy -> Torch/Tensorflow array conversion. * Fix up copy paste error
1 parent abf7d31 commit 46334b5

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

thinc/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ def xp2torch(
314314
if hasattr(xp_tensor, "toDlpack"):
315315
dlpack_tensor = xp_tensor.toDlpack() # type: ignore
316316
torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
317+
elif hasattr(xp_tensor, "__dlpack__"):
318+
torch_tensor = torch.utils.dlpack.from_dlpack(xp_tensor)
317319
else:
318320
torch_tensor = torch.from_numpy(xp_tensor)
319321
if requires_grad:
@@ -350,6 +352,8 @@ def xp2tensorflow(
350352
if hasattr(xp_tensor, "toDlpack"):
351353
dlpack_tensor = xp_tensor.toDlpack() # type: ignore
352354
tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor)
355+
elif hasattr(xp_tensor, "__dlpack__"):
356+
tf_tensor = tf.experimental.dlpack.from_dlpack(xp_tensor)
353357
else:
354358
tf_tensor = tf.convert_to_tensor(xp_tensor)
355359
if as_variable:

0 commit comments

Comments
 (0)