Skip to content

Commit 81456d3

Browse files
authored
Merge pull request #90 from Enet4/savedmodelbundle
Add SavedModelBundle API
2 parents 73ed289 + 55a41e1 commit 81456d3

File tree

5 files changed

+101
-8
lines changed

5 files changed

+101
-8
lines changed

examples/regression_savedmodel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
1212
b = tf.Variable(tf.zeros([1]), name='b')
13-
y_hat = w * x + b
13+
y_hat = tf.add(w * x, b, name="y_hat")
1414

1515
loss = tf.reduce_mean(tf.square(y_hat - y))
1616
optimizer = tf.train.GradientDescentOptimizer(0.5)

src/session.rs

+100-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::ffi::CString;
44
use std::marker;
55
use std::path::Path;
66
use std::ptr;
7+
use super::{Buffer, BufferTrait};
78
use super::Code;
89
use super::DataType;
910
use super::Graph;
@@ -16,6 +17,64 @@ use super::Status;
1617
use super::Tensor;
1718
use super::TensorType;
1819

20+
/// Aggregation type for a saved model bundle.
21+
#[derive(Debug)]
22+
pub struct SavedModelBundle {
23+
/// The loaded session.
24+
pub session: Session,
25+
/// A meta graph definition as raw protocol buffer.
26+
pub meta_graph_def: Vec<u8>,
27+
}
28+
29+
impl SavedModelBundle {
30+
31+
/// Loads a session from an exported model, creating a bundle
32+
pub fn load<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
33+
(options: &SessionOptions,
34+
tags: Tags,
35+
graph: &mut Graph,
36+
export_dir: P)
37+
-> Result<SavedModelBundle> {
38+
let mut status = Status::new();
39+
40+
let export_dir_cstr =
41+
export_dir.as_ref()
42+
.to_str()
43+
.and_then(|s| CString::new(s.as_bytes()).ok())
44+
.ok_or_else(|| invalid_arg!("Invalid export directory path"))?;
45+
46+
let tags_cstr: Vec<_> = tags.into_iter()
47+
.map(|t| CString::new(t.as_ref()))
48+
.collect::<::std::result::Result<_, _>>()
49+
.map_err(|_| invalid_arg!("Invalid tag name"))?;
50+
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
51+
52+
// The empty TF_Buffer will be filled by LoadSessionFromSavedModel
53+
let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };
54+
55+
let inner = unsafe {
56+
tf::TF_LoadSessionFromSavedModel(options.inner,
57+
ptr::null(),
58+
export_dir_cstr.as_ptr(),
59+
tags_ptr.as_ptr(),
60+
tags_ptr.len() as c_int,
61+
graph.inner(),
62+
meta.inner_mut(),
63+
status.inner())
64+
};
65+
if inner.is_null() {
66+
Err(status)
67+
} else {
68+
let session = Session { inner: inner };
69+
Ok(SavedModelBundle {
70+
session: session,
71+
meta_graph_def: Vec::from(meta.as_ref())
72+
})
73+
}
74+
}
75+
76+
}
77+
1978
/// Manages a single graph and execution.
2079
#[derive(Debug)]
2180
pub struct Session {
@@ -43,16 +102,15 @@ impl Session {
43102
-> Result<Self> {
44103
let mut status = Status::new();
45104

46-
let export_dir_cstr =
47-
try!(export_dir.as_ref()
105+
let export_dir_cstr = export_dir.as_ref()
48106
.to_str()
49107
.and_then(|s| CString::new(s.as_bytes()).ok())
50-
.ok_or_else(|| invalid_arg!("Invalid export directory path")));
108+
.ok_or_else(|| invalid_arg!("Invalid export directory path"))?;
51109

52-
let tags_cstr: Vec<_> = try!(tags.into_iter()
53-
.map(|t| CString::new(t.as_ref()))
54-
.collect::<::std::result::Result<_, _>>()
55-
.map_err(|_| invalid_arg!("Invalid tag name")));
110+
let tags_cstr: Vec<_> = tags.into_iter()
111+
.map(|t| CString::new(t.as_ref()))
112+
.collect::<::std::result::Result<_, _>>()
113+
.map_err(|_| invalid_arg!("Invalid tag name"))?;
56114
// keeping tags_cstr to retain strings in memory
57115
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
58116

@@ -326,4 +384,39 @@ mod tests {
326384
assert_eq!(output_tensor[0], 4.0);
327385
assert_eq!(output_tensor[1], 6.0);
328386
}
387+
388+
#[test]
389+
fn test_savedmodelbundle() {
390+
let mut graph = Graph::new();
391+
let bundle = SavedModelBundle::load(
392+
&SessionOptions::new(),
393+
&["train", "serve"],
394+
&mut graph,
395+
"test_resources/regression-model",
396+
).unwrap();
397+
398+
let x_op = graph.operation_by_name_required("x").unwrap();
399+
let y_op = graph.operation_by_name_required("y").unwrap();
400+
let y_hat_op = graph.operation_by_name_required("y_hat").unwrap();
401+
let _train_op = graph.operation_by_name_required("train").unwrap();
402+
403+
let SavedModelBundle {
404+
mut session,
405+
meta_graph_def,
406+
} = bundle;
407+
408+
assert!(!meta_graph_def.is_empty());
409+
410+
let mut x = <Tensor<f32>>::new(&[1]);
411+
x[0] = 2.0;
412+
let mut y = <Tensor<f32>>::new(&[1]);
413+
y[0] = 4.0;
414+
let mut step = StepWithGraph::new();
415+
step.add_input(&x_op, 0, &x);
416+
step.add_input(&y_op, 0, &y);
417+
let output_token = step.request_output(&y_hat_op, 0);
418+
session.run(&mut step).unwrap();
419+
let output_tensor = step.take_output::<f32>(output_token).unwrap();
420+
assert_eq!(output_tensor.len(), 1);
421+
}
329422
}
17.6 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)