diff --git a/Cargo.toml b/Cargo.toml index a648b09bc..c75066c5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } +dlpark = { version = "0.3.0", optional = true } + [dev-dependencies] defmac = "0.2" quickcheck = { version = "1.0", default-features = false } @@ -73,6 +75,8 @@ rayon = ["rayon_", "std"] matrixmultiply-threading = ["matrixmultiply/threading"] +dlpack = ["dep:dlpark"] + [profile.bench] debug = true [profile.dev.package.numeric-tests] diff --git a/src/data_traits.rs b/src/data_traits.rs index acf4b0b7a..7095db73d 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -17,9 +17,12 @@ use alloc::sync::Arc; use alloc::vec::Vec; use crate::{ - ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, + ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr }; +#[cfg(feature = "dlpack")] +use crate::ManagedRepr; + /// Array representation trait. /// /// For an array that meets the invariants of the `ArrayBase` type. This trait @@ -346,6 +349,24 @@ unsafe impl<A> RawData for OwnedRepr<A> { private_impl! {} } +#[cfg(feature = "dlpack")] +unsafe impl<A> RawData for ManagedRepr<A> { + type Elem = A; + + fn _data_slice(&self) -> Option<&[A]> { + Some(self.as_slice()) + } + + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool { + let slc = self.as_slice(); + let ptr = slc.as_ptr() as *mut A; + let end = unsafe { ptr.add(slc.len()) }; + self_ptr >= ptr && self_ptr <= end + } + + private_impl! {} +} + unsafe impl<A> RawDataMut for OwnedRepr<A> { #[inline] fn try_ensure_unique<D>(_: &mut ArrayBase<Self, D>) @@ -382,6 +403,28 @@ unsafe impl<A> Data for OwnedRepr<A> { } } +#[cfg(feature = "dlpack")] +unsafe impl<A> Data for ManagedRepr<A> { + #[inline] + fn into_owned<D>(self_: ArrayBase<Self, D>) -> Array<Self::Elem, D> + where + A: Clone, + D: Dimension, + { + self_.to_owned() + } + + #[inline] + fn try_into_owned_nocopy<D>( + self_: ArrayBase<Self, D>, + ) -> Result<Array<Self::Elem, D>, ArrayBase<Self, D>> + where + D: Dimension, + { + Err(self_) + } +} + unsafe impl<A> DataMut for OwnedRepr<A> {} unsafe impl<A> RawDataClone for OwnedRepr<A> diff --git a/src/dlpack.rs b/src/dlpack.rs new file mode 100644 index 000000000..d80919e11 --- /dev/null +++ b/src/dlpack.rs @@ -0,0 +1,94 @@ +use core::ptr::NonNull; +use std::marker::PhantomData; + +use dlpark::prelude::*; + +use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData}; + +impl<A, S, D> ToTensor for ArrayBase<S, D> +where + A: InferDtype, + S: RawData<Elem = A>, + D: Dimension, +{ + fn data_ptr(&self) -> *mut std::ffi::c_void { + self.as_ptr() as *mut std::ffi::c_void + } + + fn byte_offset(&self) -> u64 { + 0 + } + + fn device(&self) -> Device { + Device::CPU + } + + fn dtype(&self) -> DataType { + A::infer_dtype() + } + + fn shape(&self) -> CowIntArray { + dlpark::prelude::CowIntArray::from_owned( + self.shape().into_iter().map(|&x| x as i64).collect(), + ) + } + + fn strides(&self) -> Option<CowIntArray> { + Some(dlpark::prelude::CowIntArray::from_owned( + self.strides().into_iter().map(|&x| x as i64).collect(), + )) + } +} + +pub struct ManagedRepr<A> { + managed_tensor: ManagedTensor, + _ty: PhantomData<A>, +} + +impl<A> ManagedRepr<A> { + pub fn new(managed_tensor: ManagedTensor) -> Self { + Self { + managed_tensor, + _ty: PhantomData, + } + } + + pub fn as_slice(&self) -> &[A] { + self.managed_tensor.as_slice() + } + + pub fn as_ptr(&self) -> *const A { + self.managed_tensor.data_ptr() as *const A + } +} + +unsafe impl<A> Sync for ManagedRepr<A> where A: Sync {} +unsafe impl<A> Send for ManagedRepr<A> where A: Send {} + +impl<A> FromDLPack for ManagedArray<A, IxDyn> { + fn from_dlpack(dlpack: NonNull<dlpark::ffi::DLManagedTensor>) -> Self { + let managed_tensor = ManagedTensor::new(dlpack); + let shape: Vec<usize> = managed_tensor + .shape() + .into_iter() + .map(|x| *x as _) + .collect(); + + let strides: Vec<usize> = match (managed_tensor.strides(), managed_tensor.is_contiguous()) { + (Some(s), _) => s.into_iter().map(|&x| x as _).collect(), + (None, true) => managed_tensor + .calculate_contiguous_strides() + .into_iter() + .map(|x| x as _) + .collect(), + (None, false) => panic!("dlpack: invalid strides"), + }; + let ptr = managed_tensor.data_ptr() as *mut A; + + let managed_repr = ManagedRepr::<A>::new(managed_tensor); + unsafe { + ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr)) + .with_strides_dim(strides.into_dimension(), shape.into_dimension()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 07e5ed680..2927f80fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,9 @@ mod zip; mod dimension; +#[cfg(feature = "dlpack")] +mod dlpack; + pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; pub use crate::layout::Layout; @@ -1346,6 +1349,12 @@ pub type Array<A, D> = ArrayBase<OwnedRepr<A>, D>; /// instead of either a view or a uniquely owned copy. pub type CowArray<'a, A, D> = ArrayBase<CowRepr<'a, A>, D>; + +/// An array from managed memory +#[cfg(feature = "dlpack")] +pub type ManagedArray<A, D> = ArrayBase<ManagedRepr<A>, D>; + + /// A read-only array view. /// /// An array view represents an array or a part of it, created from @@ -1420,6 +1429,10 @@ pub type RawArrayViewMut<A, D> = ArrayBase<RawViewRepr<*mut A>, D>; pub use data_repr::OwnedRepr; +#[cfg(feature = "dlpack")] +pub use dlpack::ManagedRepr; + + /// ArcArray's representation. /// /// *Don’t use this type directly—use the type alias diff --git a/tests/dlpack.rs b/tests/dlpack.rs new file mode 100644 index 000000000..c0ba6e307 --- /dev/null +++ b/tests/dlpack.rs @@ -0,0 +1,17 @@ +#![cfg(feature = "dlpack")] + +use dlpark::prelude::*; +use ndarray::ManagedArray; + +#[test] +fn test_dlpack() { + let arr = ndarray::arr1(&[1i32, 2, 3]); + let ptr = arr.as_ptr(); + let dlpack = arr.into_dlpack(); + let arr2 = ManagedArray::<i32, _>::from_dlpack(dlpack); + let ptr2 = arr2.as_ptr(); + assert_eq!(ptr, ptr2); + let arr3 = arr2.to_owned(); + let ptr3 = arr3.as_ptr(); + assert_ne!(ptr2, ptr3); +}