diff --git a/Cargo.toml b/Cargo.toml index 6fc2182ee8..1ab743554a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ doctest = false # Prevent downloading or building TensorFlow when building docs on docs.rs. [package.metadata.docs.rs] -features = ["private-docs-rs", "tensorflow_unstable", "ndarray", "eager"] +features = ["private-docs-rs", "tensorflow_unstable", "ndarray", "eager", "experimental"] [dependencies] libc = "0.2.132" @@ -40,6 +40,7 @@ serial_test = "0.9.0" [features] default = ["tensorflow-sys"] +experimental = ["tensorflow-sys/experimental"] tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"] tensorflow_unstable = [] tensorflow_runtime_linking = ["tensorflow-sys-runtime"] diff --git a/src/lib.rs b/src/lib.rs index 68cf077166..0694eeeeb8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,6 +54,11 @@ use tensorflow_sys as tf; #[cfg(feature = "tensorflow_runtime_linking")] use tensorflow_sys_runtime as tf; +#[cfg(feature = "experimental")] +mod pluggable_device; +#[cfg(feature = "experimental")] +pub use pluggable_device::*; + //////////////////////// /// Will panic if `msg` contains an embedded 0 byte. diff --git a/src/pluggable_device.rs b/src/pluggable_device.rs new file mode 100644 index 0000000000..512afcf0cc --- /dev/null +++ b/src/pluggable_device.rs @@ -0,0 +1,53 @@ +use crate::{Result, Status}; +use std::ffi::CString; +use tensorflow_sys as tf; + +/// PluggableDeviceLibrary handler. +#[derive(Debug)] +pub struct PluggableDeviceLibrary { + inner: *mut tf::TF_Library, +} + +impl PluggableDeviceLibrary { + /// Load the library specified by library_filename and register the pluggable + /// device and related kernels present in that library. This function is not + /// supported on embedded on mobile and embedded platforms and will fail if + /// called. + /// + /// Pass "library_filename" to a platform-specific mechanism for dynamically + /// loading a library. The rules for determining the exact location of the + /// library are platform-specific and are not documented here. + pub fn load(library_filename: &str) -> Result { + let status = Status::new(); + let library_filename = CString::new(library_filename)?; + let lib_handle = + unsafe { tf::TF_LoadPluggableDeviceLibrary(library_filename.as_ptr(), status.inner) }; + status.into_result()?; + + Ok(PluggableDeviceLibrary { inner: lib_handle }) + } +} + +impl Drop for PluggableDeviceLibrary { + /// Frees the memory associated with the library handle. + /// Does NOT unload the library. + fn drop(&mut self) { + unsafe { + tf::TF_DeletePluggableDeviceLibraryHandle(self.inner); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[ignore] + #[test] + fn load_pluggable_device_library() { + let library_filename = "path-to-library"; + let pluggable_divice_library = PluggableDeviceLibrary::load(library_filename); + dbg!(&pluggable_divice_library); + assert!((pluggable_divice_library.is_ok())); + } +} diff --git a/tensorflow-sys/Cargo.toml b/tensorflow-sys/Cargo.toml index 9e4ff2d280..91eb378b14 100644 --- a/tensorflow-sys/Cargo.toml +++ b/tensorflow-sys/Cargo.toml @@ -36,6 +36,7 @@ zip = "0.6.4" [features] tensorflow_gpu = [] eager = [] +experimental = [] # This is for testing purposes; users should not use this. examples_system_alloc = [] private-docs-rs = [] # DO NOT RELY ON THIS diff --git a/tensorflow-sys/generate_bindgen_rs.sh b/tensorflow-sys/generate_bindgen_rs.sh index 6ce66390ee..b1adf5fd5e 100755 --- a/tensorflow-sys/generate_bindgen_rs.sh +++ b/tensorflow-sys/generate_bindgen_rs.sh @@ -7,14 +7,22 @@ if ! which bindgen > /dev/null; then exit 1 fi -include_dir="$HOME/git/tensorflow" +include_dir="../../tensorflow" +# Export C-API bindgen_options_c_api="--allowlist-function TF_.+ --allowlist-type TF_.+ --allowlist-var TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions" -cmd="bindgen ${bindgen_options_c_api} ${include_dir}/tensorflow/c/c_api.h --output src/c_api.rs -- -I ${include_dir}" +cmd="bindgen ${bindgen_options_c_api} ${include_dir}/tensorflow/c/c_api.h --output src/c_api.rs -- -I ${include_dir}" echo ${cmd} ${cmd} -bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --allowlist-var TFE_.+ --blocklist-type TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions" -cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}" +# Export PluggableDeviceLibrary from C-API experimental +bindgen_options_c_api_experimental="--allowlist-function TF_.+PluggableDeviceLibrary.* --blocklist-type TF_.+ --size_t-is-usize" +cmd="bindgen ${bindgen_options_c_api_experimental} ${include_dir}/tensorflow/c/c_api_experimental.h --output src/c_api_experimental.rs -- -I ${include_dir}" +echo ${cmd} +${cmd} + +# Export Eager C-API +bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --allowlist-var TFE_.+ --blocklist-type TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions --no-layout-tests" +cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}" echo ${cmd} ${cmd} diff --git a/tensorflow-sys/src/c_api_experimental.rs b/tensorflow-sys/src/c_api_experimental.rs new file mode 100644 index 0000000000..983e285bea --- /dev/null +++ b/tensorflow-sys/src/c_api_experimental.rs @@ -0,0 +1,11 @@ +/* automatically generated by rust-bindgen 0.59.1 */ + +extern "C" { + pub fn TF_LoadPluggableDeviceLibrary( + library_filename: *const ::std::os::raw::c_char, + status: *mut TF_Status, + ) -> *mut TF_Library; +} +extern "C" { + pub fn TF_DeletePluggableDeviceLibraryHandle(lib_handle: *mut TF_Library); +} diff --git a/tensorflow-sys/src/lib.rs b/tensorflow-sys/src/lib.rs index 2e3f1999d7..8b55aa6d11 100644 --- a/tensorflow-sys/src/lib.rs +++ b/tensorflow-sys/src/lib.rs @@ -7,6 +7,8 @@ mod eager; #[cfg(feature = "eager")] pub use eager::*; include!("c_api.rs"); +#[cfg(feature = "experimental")] +include!("c_api_experimental.rs"); pub use crate::TF_AttrType::*; pub use crate::TF_Code::*; diff --git a/test-all b/test-all index dbf67ab0a5..66824f16c9 100755 --- a/test-all +++ b/test-all @@ -58,8 +58,8 @@ run cargo run --example regression run cargo run --example xor run cargo run --features tensorflow_unstable --example expressions run cargo run --features eager --example mobilenetv3 -run cargo doc -vv --features tensorflow_unstable,ndarray,eager -run cargo doc -vv --features tensorflow_unstable,ndarray,eager,private-docs-rs +run cargo doc -vv --features experimental,tensorflow_unstable,ndarray,eager +run cargo doc -vv --features experimental,tensorflow_unstable,ndarray,eager,private-docs-rs # TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1) (cd tensorflow-sys && run cargo run --example multiplication) (cd tensorflow-sys && run cargo run --example tf_version)