From 512c1bfc3e6adb4a5dc1c73986d3df90e90ce3f9 Mon Sep 17 00:00:00 2001
From: SunDoge <384813529@qq.com>
Date: Fri, 14 Jul 2023 12:42:51 +0800
Subject: [PATCH 1/3] add dlpack support
---
Cargo.toml | 2 +
src/data_traits.rs | 42 ++++++++++++++++++++-
src/dlpack.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++++
src/lib.rs | 9 +++++
tests/dlpack.rs | 17 +++++++++
5 files changed, 162 insertions(+), 1 deletion(-)
create mode 100644 src/dlpack.rs
create mode 100644 tests/dlpack.rs
diff --git a/Cargo.toml b/Cargo.toml
index a648b09bc..daf4c616e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }
+dlpark = { git = "https://github.com/SunDoge/dlpark", rev = "f4d45cd" }
+
[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "1.0", default-features = false }
diff --git a/src/data_traits.rs b/src/data_traits.rs
index acf4b0b7a..df8907ce5 100644
--- a/src/data_traits.rs
+++ b/src/data_traits.rs
@@ -17,7 +17,7 @@ use alloc::sync::Arc;
use alloc::vec::Vec;
use crate::{
- ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
+ ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, ManagedRepr
};
/// Array representation trait.
@@ -346,6 +346,24 @@ unsafe impl RawData for OwnedRepr {
private_impl! {}
}
+
+unsafe impl RawData for ManagedRepr {
+ type Elem = A;
+
+ fn _data_slice(&self) -> Option<&[A]> {
+ Some(self.as_slice())
+ }
+
+ fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool {
+ let slc = self.as_slice();
+ let ptr = slc.as_ptr() as *mut A;
+ let end = unsafe { ptr.add(slc.len()) };
+ self_ptr >= ptr && self_ptr <= end
+ }
+
+ private_impl! {}
+}
+
unsafe impl RawDataMut for OwnedRepr {
#[inline]
fn try_ensure_unique(_: &mut ArrayBase)
@@ -382,6 +400,28 @@ unsafe impl Data for OwnedRepr {
}
}
+
+unsafe impl Data for ManagedRepr {
+ #[inline]
+ fn into_owned(self_: ArrayBase) -> Array
+ where
+ A: Clone,
+ D: Dimension,
+ {
+ self_.to_owned()
+ }
+
+ #[inline]
+ fn try_into_owned_nocopy(
+ self_: ArrayBase,
+ ) -> Result, ArrayBase>
+ where
+ D: Dimension,
+ {
+ Err(self_)
+ }
+}
+
unsafe impl DataMut for OwnedRepr {}
unsafe impl RawDataClone for OwnedRepr
diff --git a/src/dlpack.rs b/src/dlpack.rs
new file mode 100644
index 000000000..32b945e8e
--- /dev/null
+++ b/src/dlpack.rs
@@ -0,0 +1,93 @@
+use core::ptr::NonNull;
+use std::marker::PhantomData;
+
+use dlpark::prelude::*;
+
+use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData};
+
+impl ToTensor for ArrayBase
+where
+ A: InferDtype,
+ S: RawData,
+ D: Dimension,
+{
+ fn data_ptr(&self) -> *mut std::ffi::c_void {
+ self.as_ptr() as *mut std::ffi::c_void
+ }
+
+ fn byte_offset(&self) -> u64 {
+ 0
+ }
+
+ fn device(&self) -> Device {
+ Device::CPU
+ }
+
+ fn dtype(&self) -> DataType {
+ A::infer_dtype()
+ }
+
+ fn shape(&self) -> CowIntArray {
+ dlpark::prelude::CowIntArray::from_owned(
+ self.shape().into_iter().map(|&x| x as i64).collect(),
+ )
+ }
+
+ fn strides(&self) -> Option {
+ Some(dlpark::prelude::CowIntArray::from_owned(
+ self.strides().into_iter().map(|&x| x as i64).collect(),
+ ))
+ }
+}
+
+pub struct ManagedRepr {
+ managed_tensor: ManagedTensor,
+ _ty: PhantomData,
+}
+
+impl ManagedRepr {
+ pub fn new(managed_tensor: ManagedTensor) -> Self {
+ Self {
+ managed_tensor,
+ _ty: PhantomData,
+ }
+ }
+
+ pub fn as_slice(&self) -> &[A] {
+ self.managed_tensor.as_slice()
+ }
+
+ pub fn as_ptr(&self) -> *const A {
+ self.managed_tensor.data_ptr() as *const A
+ }
+}
+
+unsafe impl Sync for ManagedRepr where A: Sync {}
+unsafe impl Send for ManagedRepr where A: Send {}
+
+impl FromDLPack for ManagedArray {
+ fn from_dlpack(dlpack: NonNull) -> Self {
+ let managed_tensor = ManagedTensor::new(dlpack);
+ let shape: Vec = managed_tensor
+ .shape()
+ .into_iter()
+ .map(|x| *x as _)
+ .collect();
+
+ let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) {
+ (Some(s), _) => s.into_iter().map(|&x| x as _).collect(),
+ (None, true) => dlpark::tensor::calculate_contiguous_strides(managed_tensor.shape())
+ .into_iter()
+ .map(|x| x as _)
+ .collect(),
+ (None, false) => panic!("fail"),
+ };
+ let ptr = managed_tensor.data_ptr() as *mut A;
+
+ let managed_repr = ManagedRepr::::new(managed_tensor);
+ unsafe {
+ ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr))
+ .with_strides_dim(strides.into_dimension(), shape.into_dimension())
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 07e5ed680..526692943 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -207,6 +207,8 @@ mod zip;
mod dimension;
+mod dlpack;
+
pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip};
pub use crate::layout::Layout;
@@ -1346,6 +1348,11 @@ pub type Array = ArrayBase, D>;
/// instead of either a view or a uniquely owned copy.
pub type CowArray<'a, A, D> = ArrayBase, D>;
+
+/// An array from managed memory
+pub type ManagedArray = ArrayBase, D>;
+
+
/// A read-only array view.
///
/// An array view represents an array or a part of it, created from
@@ -1419,6 +1426,8 @@ pub type RawArrayView = ArrayBase, D>;
pub type RawArrayViewMut = ArrayBase, D>;
pub use data_repr::OwnedRepr;
+pub use dlpack::ManagedRepr;
+
/// ArcArray's representation.
///
diff --git a/tests/dlpack.rs b/tests/dlpack.rs
new file mode 100644
index 000000000..064441d94
--- /dev/null
+++ b/tests/dlpack.rs
@@ -0,0 +1,17 @@
+use dlpark::prelude::*;
+use ndarray::ManagedArray;
+
+#[test]
+fn test_dlpack() {
+ let arr = ndarray::arr1(&[1i32, 2, 3]);
+ let ptr = arr.as_ptr();
+ let dlpack = arr.to_dlpack();
+ let arr2 = ManagedArray::::from_dlpack(dlpack);
+ let ptr2 = arr2.as_ptr();
+ assert_eq!(ptr, ptr2);
+ // dbg!(&arr2);
+ let arr3 = arr2.to_owned();
+ // dbg!(&arr3);
+ let ptr3 = arr3.as_ptr();
+ assert_ne!(ptr2, ptr3);
+}
From 32754cac8d4470d5b0356860b8db1db723f3646c Mon Sep 17 00:00:00 2001
From: SunDoge <384813529@qq.com>
Date: Tue, 18 Jul 2023 17:25:54 +0800
Subject: [PATCH 2/3] add feature gate
---
Cargo.toml | 4 +++-
src/dlpack.rs | 5 +++--
src/lib.rs | 4 ++++
tests/dlpack.rs | 6 +++---
4 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index daf4c616e..c75066c5d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,7 +46,7 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }
-dlpark = { git = "https://github.com/SunDoge/dlpark", rev = "f4d45cd" }
+dlpark = { version = "0.3.0", optional = true }
[dev-dependencies]
defmac = "0.2"
@@ -75,6 +75,8 @@ rayon = ["rayon_", "std"]
matrixmultiply-threading = ["matrixmultiply/threading"]
+dlpack = ["dep:dlpark"]
+
[profile.bench]
debug = true
[profile.dev.package.numeric-tests]
diff --git a/src/dlpack.rs b/src/dlpack.rs
index 32b945e8e..d80919e11 100644
--- a/src/dlpack.rs
+++ b/src/dlpack.rs
@@ -76,11 +76,12 @@ impl FromDLPack for ManagedArray {
let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) {
(Some(s), _) => s.into_iter().map(|&x| x as _).collect(),
- (None, true) => dlpark::tensor::calculate_contiguous_strides(managed_tensor.shape())
+ (None, true) => managed_tensor
+ .calculate_contiguous_strides()
.into_iter()
.map(|x| x as _)
.collect(),
- (None, false) => panic!("fail"),
+ (None, false) => panic!("dlpack: invalid strides"),
};
let ptr = managed_tensor.data_ptr() as *mut A;
diff --git a/src/lib.rs b/src/lib.rs
index 526692943..2927f80fc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -207,6 +207,7 @@ mod zip;
mod dimension;
+#[cfg(feature = "dlpack")]
mod dlpack;
pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip};
@@ -1350,6 +1351,7 @@ pub type CowArray<'a, A, D> = ArrayBase, D>;
/// An array from managed memory
+#[cfg(feature = "dlpack")]
pub type ManagedArray = ArrayBase, D>;
@@ -1426,6 +1428,8 @@ pub type RawArrayView = ArrayBase, D>;
pub type RawArrayViewMut = ArrayBase, D>;
pub use data_repr::OwnedRepr;
+
+#[cfg(feature = "dlpack")]
pub use dlpack::ManagedRepr;
diff --git a/tests/dlpack.rs b/tests/dlpack.rs
index 064441d94..c0ba6e307 100644
--- a/tests/dlpack.rs
+++ b/tests/dlpack.rs
@@ -1,3 +1,5 @@
+#![cfg(feature = "dlpack")]
+
use dlpark::prelude::*;
use ndarray::ManagedArray;
@@ -5,13 +7,11 @@ use ndarray::ManagedArray;
fn test_dlpack() {
let arr = ndarray::arr1(&[1i32, 2, 3]);
let ptr = arr.as_ptr();
- let dlpack = arr.to_dlpack();
+ let dlpack = arr.into_dlpack();
let arr2 = ManagedArray::::from_dlpack(dlpack);
let ptr2 = arr2.as_ptr();
assert_eq!(ptr, ptr2);
- // dbg!(&arr2);
let arr3 = arr2.to_owned();
- // dbg!(&arr3);
let ptr3 = arr3.as_ptr();
assert_ne!(ptr2, ptr3);
}
From c1a090839d6fb7bab127640005108e5f017b26f6 Mon Sep 17 00:00:00 2001
From: SunDoge <384813529@qq.com>
Date: Tue, 18 Jul 2023 17:32:03 +0800
Subject: [PATCH 3/3] fix cargo test
---
src/data_traits.rs | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/src/data_traits.rs b/src/data_traits.rs
index df8907ce5..7095db73d 100644
--- a/src/data_traits.rs
+++ b/src/data_traits.rs
@@ -17,9 +17,12 @@ use alloc::sync::Arc;
use alloc::vec::Vec;
use crate::{
- ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, ManagedRepr
+ ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr
};
+#[cfg(feature = "dlpack")]
+use crate::ManagedRepr;
+
/// Array representation trait.
///
/// For an array that meets the invariants of the `ArrayBase` type. This trait
@@ -346,7 +349,7 @@ unsafe impl RawData for OwnedRepr {
private_impl! {}
}
-
+#[cfg(feature = "dlpack")]
unsafe impl RawData for ManagedRepr {
type Elem = A;
@@ -400,7 +403,7 @@ unsafe impl Data for OwnedRepr {
}
}
-
+#[cfg(feature = "dlpack")]
unsafe impl Data for ManagedRepr {
#[inline]
fn into_owned(self_: ArrayBase) -> Array