From 512c1bfc3e6adb4a5dc1c73986d3df90e90ce3f9 Mon Sep 17 00:00:00 2001 From: SunDoge <384813529@qq.com> Date: Fri, 14 Jul 2023 12:42:51 +0800 Subject: [PATCH 1/3] add dlpack support --- Cargo.toml | 2 + src/data_traits.rs | 42 ++++++++++++++++++++- src/dlpack.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 9 +++++ tests/dlpack.rs | 17 +++++++++ 5 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 src/dlpack.rs create mode 100644 tests/dlpack.rs diff --git a/Cargo.toml b/Cargo.toml index a648b09bc..daf4c616e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } +dlpark = { git = "https://github.com/SunDoge/dlpark", rev = "f4d45cd" } + [dev-dependencies] defmac = "0.2" quickcheck = { version = "1.0", default-features = false } diff --git a/src/data_traits.rs b/src/data_traits.rs index acf4b0b7a..df8907ce5 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -17,7 +17,7 @@ use alloc::sync::Arc; use alloc::vec::Vec; use crate::{ - ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, + ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, ManagedRepr }; /// Array representation trait. @@ -346,6 +346,24 @@ unsafe impl RawData for OwnedRepr { private_impl! {} } + +unsafe impl RawData for ManagedRepr { + type Elem = A; + + fn _data_slice(&self) -> Option<&[A]> { + Some(self.as_slice()) + } + + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool { + let slc = self.as_slice(); + let ptr = slc.as_ptr() as *mut A; + let end = unsafe { ptr.add(slc.len()) }; + self_ptr >= ptr && self_ptr <= end + } + + private_impl! {} +} + unsafe impl RawDataMut for OwnedRepr { #[inline] fn try_ensure_unique(_: &mut ArrayBase) @@ -382,6 +400,28 @@ unsafe impl Data for OwnedRepr { } } + +unsafe impl Data for ManagedRepr { + #[inline] + fn into_owned(self_: ArrayBase) -> Array + where + A: Clone, + D: Dimension, + { + self_.to_owned() + } + + #[inline] + fn try_into_owned_nocopy( + self_: ArrayBase, + ) -> Result, ArrayBase> + where + D: Dimension, + { + Err(self_) + } +} + unsafe impl DataMut for OwnedRepr {} unsafe impl RawDataClone for OwnedRepr diff --git a/src/dlpack.rs b/src/dlpack.rs new file mode 100644 index 000000000..32b945e8e --- /dev/null +++ b/src/dlpack.rs @@ -0,0 +1,93 @@ +use core::ptr::NonNull; +use std::marker::PhantomData; + +use dlpark::prelude::*; + +use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData}; + +impl ToTensor for ArrayBase +where + A: InferDtype, + S: RawData, + D: Dimension, +{ + fn data_ptr(&self) -> *mut std::ffi::c_void { + self.as_ptr() as *mut std::ffi::c_void + } + + fn byte_offset(&self) -> u64 { + 0 + } + + fn device(&self) -> Device { + Device::CPU + } + + fn dtype(&self) -> DataType { + A::infer_dtype() + } + + fn shape(&self) -> CowIntArray { + dlpark::prelude::CowIntArray::from_owned( + self.shape().into_iter().map(|&x| x as i64).collect(), + ) + } + + fn strides(&self) -> Option { + Some(dlpark::prelude::CowIntArray::from_owned( + self.strides().into_iter().map(|&x| x as i64).collect(), + )) + } +} + +pub struct ManagedRepr { + managed_tensor: ManagedTensor, + _ty: PhantomData, +} + +impl ManagedRepr { + pub fn new(managed_tensor: ManagedTensor) -> Self { + Self { + managed_tensor, + _ty: PhantomData, + } + } + + pub fn as_slice(&self) -> &[A] { + self.managed_tensor.as_slice() + } + + pub fn as_ptr(&self) -> *const A { + self.managed_tensor.data_ptr() as *const A + } +} + +unsafe impl Sync for ManagedRepr where A: Sync {} +unsafe impl Send for ManagedRepr where A: Send {} + +impl FromDLPack for ManagedArray { + fn from_dlpack(dlpack: NonNull) -> Self { + let managed_tensor = ManagedTensor::new(dlpack); + let shape: Vec = managed_tensor + .shape() + .into_iter() + .map(|x| *x as _) + .collect(); + + let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) { + (Some(s), _) => s.into_iter().map(|&x| x as _).collect(), + (None, true) => dlpark::tensor::calculate_contiguous_strides(managed_tensor.shape()) + .into_iter() + .map(|x| x as _) + .collect(), + (None, false) => panic!("fail"), + }; + let ptr = managed_tensor.data_ptr() as *mut A; + + let managed_repr = ManagedRepr::::new(managed_tensor); + unsafe { + ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr)) + .with_strides_dim(strides.into_dimension(), shape.into_dimension()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 07e5ed680..526692943 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,8 @@ mod zip; mod dimension; +mod dlpack; + pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; pub use crate::layout::Layout; @@ -1346,6 +1348,11 @@ pub type Array = ArrayBase, D>; /// instead of either a view or a uniquely owned copy. pub type CowArray<'a, A, D> = ArrayBase, D>; + +/// An array from managed memory +pub type ManagedArray = ArrayBase, D>; + + /// A read-only array view. /// /// An array view represents an array or a part of it, created from @@ -1419,6 +1426,8 @@ pub type RawArrayView = ArrayBase, D>; pub type RawArrayViewMut = ArrayBase, D>; pub use data_repr::OwnedRepr; +pub use dlpack::ManagedRepr; + /// ArcArray's representation. /// diff --git a/tests/dlpack.rs b/tests/dlpack.rs new file mode 100644 index 000000000..064441d94 --- /dev/null +++ b/tests/dlpack.rs @@ -0,0 +1,17 @@ +use dlpark::prelude::*; +use ndarray::ManagedArray; + +#[test] +fn test_dlpack() { + let arr = ndarray::arr1(&[1i32, 2, 3]); + let ptr = arr.as_ptr(); + let dlpack = arr.to_dlpack(); + let arr2 = ManagedArray::::from_dlpack(dlpack); + let ptr2 = arr2.as_ptr(); + assert_eq!(ptr, ptr2); + // dbg!(&arr2); + let arr3 = arr2.to_owned(); + // dbg!(&arr3); + let ptr3 = arr3.as_ptr(); + assert_ne!(ptr2, ptr3); +} From 32754cac8d4470d5b0356860b8db1db723f3646c Mon Sep 17 00:00:00 2001 From: SunDoge <384813529@qq.com> Date: Tue, 18 Jul 2023 17:25:54 +0800 Subject: [PATCH 2/3] add feature gate --- Cargo.toml | 4 +++- src/dlpack.rs | 5 +++-- src/lib.rs | 4 ++++ tests/dlpack.rs | 6 +++--- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index daf4c616e..c75066c5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } -dlpark = { git = "https://github.com/SunDoge/dlpark", rev = "f4d45cd" } +dlpark = { version = "0.3.0", optional = true } [dev-dependencies] defmac = "0.2" @@ -75,6 +75,8 @@ rayon = ["rayon_", "std"] matrixmultiply-threading = ["matrixmultiply/threading"] +dlpack = ["dep:dlpark"] + [profile.bench] debug = true [profile.dev.package.numeric-tests] diff --git a/src/dlpack.rs b/src/dlpack.rs index 32b945e8e..d80919e11 100644 --- a/src/dlpack.rs +++ b/src/dlpack.rs @@ -76,11 +76,12 @@ impl FromDLPack for ManagedArray { let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) { (Some(s), _) => s.into_iter().map(|&x| x as _).collect(), - (None, true) => dlpark::tensor::calculate_contiguous_strides(managed_tensor.shape()) + (None, true) => managed_tensor + .calculate_contiguous_strides() .into_iter() .map(|x| x as _) .collect(), - (None, false) => panic!("fail"), + (None, false) => panic!("dlpack: invalid strides"), }; let ptr = managed_tensor.data_ptr() as *mut A; diff --git a/src/lib.rs b/src/lib.rs index 526692943..2927f80fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,7 @@ mod zip; mod dimension; +#[cfg(feature = "dlpack")] mod dlpack; pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; @@ -1350,6 +1351,7 @@ pub type CowArray<'a, A, D> = ArrayBase, D>; /// An array from managed memory +#[cfg(feature = "dlpack")] pub type ManagedArray = ArrayBase, D>; @@ -1426,6 +1428,8 @@ pub type RawArrayView = ArrayBase, D>; pub type RawArrayViewMut = ArrayBase, D>; pub use data_repr::OwnedRepr; + +#[cfg(feature = "dlpack")] pub use dlpack::ManagedRepr; diff --git a/tests/dlpack.rs b/tests/dlpack.rs index 064441d94..c0ba6e307 100644 --- a/tests/dlpack.rs +++ b/tests/dlpack.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "dlpack")] + use dlpark::prelude::*; use ndarray::ManagedArray; @@ -5,13 +7,11 @@ use ndarray::ManagedArray; fn test_dlpack() { let arr = ndarray::arr1(&[1i32, 2, 3]); let ptr = arr.as_ptr(); - let dlpack = arr.to_dlpack(); + let dlpack = arr.into_dlpack(); let arr2 = ManagedArray::::from_dlpack(dlpack); let ptr2 = arr2.as_ptr(); assert_eq!(ptr, ptr2); - // dbg!(&arr2); let arr3 = arr2.to_owned(); - // dbg!(&arr3); let ptr3 = arr3.as_ptr(); assert_ne!(ptr2, ptr3); } From c1a090839d6fb7bab127640005108e5f017b26f6 Mon Sep 17 00:00:00 2001 From: SunDoge <384813529@qq.com> Date: Tue, 18 Jul 2023 17:32:03 +0800 Subject: [PATCH 3/3] fix cargo test --- src/data_traits.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/data_traits.rs b/src/data_traits.rs index df8907ce5..7095db73d 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -17,9 +17,12 @@ use alloc::sync::Arc; use alloc::vec::Vec; use crate::{ - ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, ManagedRepr + ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr }; +#[cfg(feature = "dlpack")] +use crate::ManagedRepr; + /// Array representation trait. /// /// For an array that meets the invariants of the `ArrayBase` type. This trait @@ -346,7 +349,7 @@ unsafe impl RawData for OwnedRepr { private_impl! {} } - +#[cfg(feature = "dlpack")] unsafe impl RawData for ManagedRepr { type Elem = A; @@ -400,7 +403,7 @@ unsafe impl Data for OwnedRepr { } } - +#[cfg(feature = "dlpack")] unsafe impl Data for ManagedRepr { #[inline] fn into_owned(self_: ArrayBase) -> Array