Skip to content

Commit 8bd62df

Browse files
authored
Merge pull request #68 from Enet4/export-saved-model
Load from saved model support
2 parents 022f4d0 + e0f4612 commit 8bd62df

File tree

3 files changed

+199
-15
lines changed

3 files changed

+199
-15
lines changed

examples/regression_savedmodel.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import tensorflow as tf
2+
from tensorflow.python.saved_model.builder import SavedModelBuilder
3+
from tensorflow.python.saved_model.signature_def_utils import build_signature_def
4+
from tensorflow.python.saved_model.signature_constants import REGRESS_METHOD_NAME
5+
from tensorflow.python.saved_model.tag_constants import TRAINING, SERVING
6+
from tensorflow.python.saved_model.utils import build_tensor_info
7+
8+
x = tf.placeholder(tf.float32, name='x')
9+
y = tf.placeholder(tf.float32, name='y')
10+
11+
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
12+
b = tf.Variable(tf.zeros([1]), name='b')
13+
y_hat = w * x + b
14+
15+
loss = tf.reduce_mean(tf.square(y_hat - y))
16+
optimizer = tf.train.GradientDescentOptimizer(0.5)
17+
train = optimizer.minimize(loss, name='train')
18+
19+
init = tf.variables_initializer(tf.global_variables(), name='init')
20+
21+
directory = 'examples/saved-regression-model'
22+
builder = SavedModelBuilder(directory)
23+
24+
with tf.Session(graph=tf.get_default_graph()) as sess:
25+
sess.run(init)
26+
27+
signature_inputs = {
28+
"x": build_tensor_info(x),
29+
"y": build_tensor_info(y)
30+
}
31+
signature_outputs = {
32+
"out": build_tensor_info(y_hat)
33+
}
34+
signature_def = build_signature_def(
35+
signature_inputs, signature_outputs,
36+
REGRESS_METHOD_NAME)
37+
builder.add_meta_graph_and_variables(
38+
sess, [TRAINING, SERVING],
39+
signature_def_map={
40+
REGRESS_METHOD_NAME: signature_def
41+
},
42+
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
43+
builder.save(as_text=False)

examples/regression_savedmodel.rs

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
extern crate random;
2+
extern crate tensorflow;
3+
4+
use random::Source;
5+
use std::error::Error;
6+
use std::result::Result;
7+
use std::path::Path;
8+
use std::process::exit;
9+
use tensorflow::Code;
10+
use tensorflow::Graph;
11+
use tensorflow::Session;
12+
use tensorflow::SessionOptions;
13+
use tensorflow::Status;
14+
use tensorflow::StepWithGraph;
15+
use tensorflow::Tensor;
16+
17+
fn main() {
18+
// Putting the main code in another function serves two purposes:
19+
// 1. We can use the try! macro.
20+
// 2. We can call exit safely, which does not run any destructors.
21+
exit(match run() {
22+
Ok(_) => 0,
23+
Err(e) => {
24+
println!("{}", e);
25+
1
26+
}
27+
})
28+
}
29+
30+
fn run() -> Result<(), Box<Error>> {
31+
let export_dir = "examples/saved-regression-model"; // y = w * x + b
32+
if !Path::new(export_dir).exists() {
33+
return Err(Box::new(Status::new_set(Code::NotFound,
34+
&format!("Run 'python regression_savedmodel.py' to generate \
35+
{} and try again.",
36+
export_dir))
37+
.unwrap()));
38+
}
39+
40+
// Generate some test data.
41+
let w = 0.1;
42+
let b = 0.3;
43+
let num_points = 100;
44+
let steps = 201;
45+
let mut rand = random::default();
46+
let mut x = Tensor::new(&[num_points as u64]);
47+
let mut y = Tensor::new(&[num_points as u64]);
48+
for i in 0..num_points {
49+
x[i] = (2.0 * rand.read::<f64>() - 1.0) as f32;
50+
y[i] = w * x[i] + b;
51+
}
52+
53+
// Load the saved model exported by regression_savedmodel.py.
54+
let mut graph = Graph::new();
55+
let mut session = Session::from_saved_model(&SessionOptions::new(),
56+
&["train", "serve"],
57+
&mut graph,
58+
export_dir)?;
59+
let op_x = graph.operation_by_name_required("x")?;
60+
let op_y = graph.operation_by_name_required("y")?;
61+
let op_train = graph.operation_by_name_required("train")?;
62+
let op_w = graph.operation_by_name_required("w")?;
63+
let op_b = graph.operation_by_name_required("b")?;
64+
65+
// Train the model (e.g. for fine tuning).
66+
let mut train_step = StepWithGraph::new();
67+
train_step.add_input(&op_x, 0, &x);
68+
train_step.add_input(&op_y, 0, &y);
69+
train_step.add_target(&op_train);
70+
for _ in 0..steps {
71+
try!(session.run(&mut train_step));
72+
}
73+
74+
// Grab the data out of the session.
75+
let mut output_step = StepWithGraph::new();
76+
let w_ix = output_step.request_output(&op_w, 0);
77+
let b_ix = output_step.request_output(&op_b, 0);
78+
try!(session.run(&mut output_step));
79+
80+
// Check our results.
81+
let w_hat: f32 = try!(output_step.take_output(w_ix)).data()[0];
82+
let b_hat: f32 = try!(output_step.take_output(b_ix)).data()[0];
83+
println!("Checking w: expected {}, got {}. {}",
84+
w,
85+
w_hat,
86+
if (w - w_hat).abs() < 1e-3 {
87+
"Success!"
88+
} else {
89+
"FAIL"
90+
});
91+
println!("Checking b: expected {}, got {}. {}",
92+
b,
93+
b_hat,
94+
if (b - b_hat).abs() < 1e-3 {
95+
"Success!"
96+
} else {
97+
"FAIL"
98+
});
99+
Ok(())
100+
}

src/session.rs

+56-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use tf;
2-
use libc::c_int;
2+
use libc::{c_char, c_int};
3+
use std::ffi::CString;
34
use std::marker;
5+
use std::path::Path;
46
use std::ptr;
57
use super::Code;
68
use super::DataType;
@@ -32,6 +34,45 @@ impl Session {
3234
}
3335
}
3436

37+
/// Loads a session from an exported model.
38+
pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
39+
(options: &SessionOptions,
40+
tags: Tags,
41+
graph: &mut Graph,
42+
export_dir: P)
43+
-> Result<Self> {
44+
let mut status = Status::new();
45+
46+
let export_dir_cstr =
47+
try!(export_dir.as_ref()
48+
.to_str()
49+
.and_then(|s| CString::new(s.as_bytes()).ok())
50+
.ok_or_else(|| invalid_arg!("Invalid export directory path")));
51+
52+
let tags_cstr: Vec<_> = try!(tags.into_iter()
53+
.map(|t| CString::new(t.as_ref()))
54+
.collect::<::std::result::Result<_, _>>()
55+
.map_err(|_| invalid_arg!("Invalid tag name")));
56+
// keeping tags_cstr to retain strings in memory
57+
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
58+
59+
let inner = unsafe {
60+
tf::TF_LoadSessionFromSavedModel(options.inner,
61+
ptr::null(),
62+
export_dir_cstr.as_ptr(),
63+
tags_ptr.as_ptr(),
64+
tags_ptr.len() as c_int,
65+
graph.inner(),
66+
ptr::null_mut(),
67+
status.inner())
68+
};
69+
if inner.is_null() {
70+
Err(status)
71+
} else {
72+
Ok(Session { inner: inner })
73+
}
74+
}
75+
3576
/// Closes the session.
3677
pub fn close(&mut self) -> Result<()> {
3778
let mut status = Status::new();
@@ -143,19 +184,19 @@ impl<'l> StepWithGraph<'l> {
143184
index: c_int,
144185
tensor: &'l Tensor<T>) {
145186
self.input_ports.push(tf::TF_Output {
146-
oper: operation.inner(),
147-
index: index,
148-
});
187+
oper: operation.inner(),
188+
index: index,
189+
});
149190
self.input_tensors.push(tensor.inner);
150191
}
151192

152193
/// Requests that an output is fetched from the graph after running this step.
153194
/// Returns an index that you can then use to fetch this output from the step after running it.
154195
pub fn request_output(&mut self, operation: &Operation, index: c_int) -> OutputToken {
155196
self.output_ports.push(tf::TF_Output {
156-
oper: operation.inner(),
157-
index: index,
158-
});
197+
oper: operation.inner(),
198+
index: index,
199+
});
159200
self.output_tensors.push(ptr::null_mut());
160201
OutputToken { index: self.output_tensors.len() - 1 }
161202
}
@@ -172,13 +213,13 @@ impl<'l> StepWithGraph<'l> {
172213
{}",
173214
output_idx,
174215
self.output_tensors.len()))
175-
.unwrap());
216+
.unwrap());
176217
}
177218
if self.output_tensors[output_idx].is_null() {
178219
return Err(Status::new_set(Code::Unavailable,
179220
"Output not available. Either it was already taken, or \
180221
this step has not been sucessfully run yet.")
181-
.unwrap());
222+
.unwrap());
182223
}
183224
let actual_data_type = self.output_data_type(output_idx).unwrap();
184225
if actual_data_type != T::data_type() {
@@ -260,13 +301,13 @@ mod tests {
260301
let y = {
261302
let mut nd = g.new_operation("Mul", "y").unwrap();
262303
nd.add_input(Output {
263-
operation: &two,
264-
index: 0,
265-
});
304+
operation: &two,
305+
index: 0,
306+
});
266307
nd.add_input(Output {
267-
operation: &x,
268-
index: 0,
269-
});
308+
operation: &x,
309+
index: 0,
310+
});
270311
nd.finish().unwrap()
271312
};
272313
let options = SessionOptions::new();

0 commit comments

Comments
 (0)