Skip to content

Load from saved model support #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 11, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion src/session.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use tf;
use libc::c_int;
use libc::{c_char, c_int};
use std::ffi::CString;
use std::marker;
use std::path::Path;
use std::ptr;
use std::result::Result as StdResult;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm I find this confusing, but if @adamcrume is good with it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I just needed a hint on line 56 to pluck the result out, and this approach sounded elegant to me. We could consider something else though...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other places we're naming std::result::Result directly without importing it, like so:

pub type Result<T> = std::result::Result<T, Status>;

I'd prefer that, just for consistency. On a side note, I regret adding that type alias. I assumed it was an accepted pattern because of std::fmt::Result, std::io::Result, and std::thread::Result, but having multiple types with the same name (just in different modules) causes no end of headaches.

use super::Code;
use super::DataType;
use super::Graph;
Expand Down Expand Up @@ -32,6 +35,43 @@ impl Session {
}
}

/// Loads a session from an exported model.
pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item=Tag>>(
options: &SessionOptions, tags: Tags, graph: &mut Graph, export_dir: P) -> Result<Self> {
let mut status = Status::new();

let export_dir_cstr = try!(export_dir.as_ref().to_str()
.and_then(|s| CString::new(s.as_bytes()).ok())
.ok_or_else(|| Status::new_set(
Code::InvalidArgument, "Invalid export directory path").unwrap()));

let tags_cstr: Vec<_> = try!(tags.into_iter()
.map(|t| CString::new(t.as_ref()))
.collect::<StdResult<_,_>>()
.map_err(|_| Status::new_set(Code::InvalidArgument, "Invalid tag name").unwrap()));
// keeping tags_cstr to retain strings in memory
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();

let inner = unsafe {
tf::TF_LoadSessionFromSavedModel(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatting looks a bit off here, can you run it through rustfmt please?

options.inner,
ptr::null(),
export_dir_cstr.to_bytes_with_nul().as_ptr() as *const c_char,
tags_ptr.as_ptr(),
tags_ptr.len() as c_int,
graph.inner(),
ptr::null_mut(),
status.inner())
};
if inner.is_null() {
Err(status)
} else {
Ok(Session {
inner: inner,
})
}
}

/// Closes the session.
pub fn close(&mut self) -> Result<()> {
let mut status = Status::new();
Expand Down
8 changes: 8 additions & 0 deletions tensorflow-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,14 @@ extern "C" {
opts: *const TF_SessionOptions,
status: *mut TF_Status)
-> *mut TF_Session;
pub fn TF_LoadSessionFromSavedModel(session_options: *const TF_SessionOptions,
run_options: *const TF_Buffer,
export_dir: *const c_char,
tags: *const *const c_char,
tags_len: c_int,
graph: *mut TF_Graph,
meta_graph_def: *mut TF_Buffer,
status: *mut TF_Status) -> *mut TF_Session;
pub fn TF_CloseSession(session: *mut TF_Session, status: *mut TF_Status);
pub fn TF_DeleteSession(session: *mut TF_Session, status: *mut TF_Status);
pub fn TF_SessionRun(session: *mut TF_Session,
Expand Down