Skip to content

Commit 657b9c8

Browse files
authored
Merge pull request #159 from bekker/feature/example-checkpoint
Add an example for saving and loading checkpoints
2 parents 57151be + 2acd431 commit 657b9c8

16 files changed

+177
-8
lines changed

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,6 @@ name = "regression"
3838

3939
[[example]]
4040
name = "regression_savedmodel"
41+
42+
[[example]]
43+
name = "regression_checkpoint"

examples/addition.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ fn main() {
3232
}
3333

3434
fn run() -> Result<(), Box<Error>> {
35-
let filename = "examples/addition-model/model.pb"; // z = x + y
35+
let filename = "examples/addition/model.pb"; // z = x + y
3636
if !Path::new(filename).exists() {
3737
return Err(Box::new(Status::new_set(Code::NotFound,
3838
&format!("Run 'python addition.py' to generate {} \

examples/addition.py renamed to examples/addition/addition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
tf.variables_initializer(tf.global_variables(), name = 'init')
99

1010
definition = tf.Session().graph_def
11-
directory = 'examples/addition-model'
11+
directory = 'examples/addition'
1212
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)

examples/addition/model.pb

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
,
3+
x Placeholder*
4+
shape:*
5+
dtype0
6+
,
7+
y Placeholder*
8+
dtype0*
9+
shape:
10+

11+
zAddxy*
12+
T0
13+
14+
initNoOp"

examples/regression.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn main() {
3434
}
3535

3636
fn run() -> Result<(), Box<Error>> {
37-
let filename = "examples/regression-model/model.pb"; // y = w * x + b
37+
let filename = "examples/regression/model.pb"; // y = w * x + b
3838
if !Path::new(filename).exists() {
3939
return Err(Box::new(Status::new_set(Code::NotFound,
4040
&format!("Run 'python regression.py' to generate \
@@ -71,8 +71,6 @@ fn run() -> Result<(), Box<Error>> {
7171

7272
// Load the test data into the session.
7373
let mut init_step = StepWithGraph::new();
74-
init_step.add_input(&op_x, 0, &x);
75-
init_step.add_input(&op_y, 0, &y);
7674
init_step.add_target(&op_init);
7775
session.run(&mut init_step)?;
7876

File renamed without changes.

examples/regression.py renamed to examples/regression/regression.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515
init = tf.variables_initializer(tf.global_variables(), name='init')
1616

1717
definition = tf.Session().graph_def
18-
directory = 'examples/regression-model'
18+
directory = 'examples/regression'
1919
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)

examples/regression_checkpoint.rs

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#![cfg_attr(feature="nightly", feature(alloc_system))]
2+
#[cfg(feature="nightly")]
3+
extern crate alloc_system;
4+
extern crate random;
5+
extern crate tensorflow;
6+
7+
use random::Source;
8+
use std::error::Error;
9+
use std::fs::File;
10+
use std::io::Read;
11+
use std::result::Result;
12+
use std::path::Path;
13+
use std::process::exit;
14+
use tensorflow::Code;
15+
use tensorflow::Graph;
16+
use tensorflow::ImportGraphDefOptions;
17+
use tensorflow::Session;
18+
use tensorflow::SessionOptions;
19+
use tensorflow::Status;
20+
use tensorflow::StepWithGraph;
21+
use tensorflow::Tensor;
22+
23+
fn main() {
24+
// Putting the main code in another function serves two purposes:
25+
// 1. We can use the `?` operator.
26+
// 2. We can call exit safely, which does not run any destructors.
27+
exit(match run() {
28+
Ok(_) => 0,
29+
Err(e) => {
30+
println!("{}", e);
31+
1
32+
}
33+
})
34+
}
35+
36+
fn run() -> Result<(), Box<Error>> {
37+
let filename = "examples/regression_checkpoint/model.pb"; // y = w * x + b
38+
if !Path::new(filename).exists() {
39+
return Err(Box::new(Status::new_set(Code::NotFound,
40+
&format!("Run 'python regression_checkpoint.py' to generate \
41+
{} and try again.",
42+
filename))
43+
.unwrap()));
44+
}
45+
46+
// Generate some test data.
47+
let w = 0.1;
48+
let b = 0.3;
49+
let num_points = 100;
50+
let steps = 201;
51+
let mut rand = random::default();
52+
let mut x = Tensor::new(&[num_points as u64]);
53+
let mut y = Tensor::new(&[num_points as u64]);
54+
for i in 0..num_points {
55+
x[i] = (2.0 * rand.read::<f64>() - 1.0) as f32;
56+
y[i] = w * x[i] + b;
57+
}
58+
59+
// Load the computation graph defined by regression.py.
60+
let mut graph = Graph::new();
61+
let mut proto = Vec::new();
62+
File::open(filename)?.read_to_end(&mut proto)?;
63+
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
64+
let mut session = Session::new(&SessionOptions::new(), &graph)?;
65+
let op_x = graph.operation_by_name_required("x")?;
66+
let op_y = graph.operation_by_name_required("y")?;
67+
let op_init = graph.operation_by_name_required("init")?;
68+
let op_train = graph.operation_by_name_required("train")?;
69+
let op_w = graph.operation_by_name_required("w")?;
70+
let op_b = graph.operation_by_name_required("b")?;
71+
let op_file_path = graph.operation_by_name_required("save/Const")?;
72+
let op_save = graph.operation_by_name_required("save/control_dependency")?;
73+
let file_path_tensor: Tensor<String> = Tensor::from(String::from("examples/regression_checkpoint/saved.ckpt"));
74+
75+
// Load the test data into the session.
76+
let mut init_step = StepWithGraph::new();
77+
init_step.add_target(&op_init);
78+
session.run(&mut init_step)?;
79+
80+
// Train the model.
81+
let mut train_step = StepWithGraph::new();
82+
train_step.add_input(&op_x, 0, &x);
83+
train_step.add_input(&op_y, 0, &y);
84+
train_step.add_target(&op_train);
85+
for _ in 0..steps {
86+
session.run(&mut train_step)?;
87+
}
88+
89+
// Save the model.
90+
let mut step = StepWithGraph::new();
91+
step.add_input(&op_file_path, 0, &file_path_tensor);
92+
step.add_target(&op_save);
93+
session.run(&mut step)?;
94+
95+
// Initialize variables, to erase trained data.
96+
session.run(&mut init_step)?;
97+
98+
// Load the model.
99+
let op_load = graph.operation_by_name_required("save/restore_all")?;
100+
let mut step = StepWithGraph::new();
101+
step.add_input(&op_file_path, 0, &file_path_tensor);
102+
step.add_target(&op_load);
103+
session.run(&mut step)?;
104+
105+
// Grab the data out of the session.
106+
let mut output_step = StepWithGraph::new();
107+
let w_ix = output_step.request_output(&op_w, 0);
108+
let b_ix = output_step.request_output(&op_b, 0);
109+
session.run(&mut output_step)?;
110+
111+
// Check our results.
112+
let w_hat: f32 = output_step.take_output(w_ix)?[0];
113+
let b_hat: f32 = output_step.take_output(b_ix)?[0];
114+
println!("Checking w: expected {}, got {}. {}",
115+
w,
116+
w_hat,
117+
if (w - w_hat).abs() < 1e-3 {
118+
"Success!"
119+
} else {
120+
"FAIL"
121+
});
122+
println!("Checking b: expected {}, got {}. {}",
123+
b,
124+
b_hat,
125+
if (b - b_hat).abs() < 1e-3 {
126+
"Success!"
127+
} else {
128+
"FAIL"
129+
});
130+
Ok(())
131+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
saved.ckpt*
10.8 KB
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import os
2+
import tensorflow as tf
3+
4+
x = tf.placeholder(tf.float32, name='x')
5+
y = tf.placeholder(tf.float32, name='y')
6+
7+
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
8+
b = tf.Variable(tf.zeros([1]), name='b')
9+
y_hat = w * x + b
10+
11+
loss = tf.reduce_mean(tf.square(y_hat - y))
12+
optimizer = tf.train.GradientDescentOptimizer(0.5)
13+
train = optimizer.minimize(loss, name='train')
14+
15+
init = tf.variables_initializer(tf.global_variables(), name='init')
16+
17+
# Declare saver ops
18+
saver = tf.train.Saver(tf.global_variables())
19+
20+
definition = tf.Session().graph_def
21+
directory = 'examples/regression_checkpoint'
22+
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)

examples/regression_savedmodel.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fn main() {
3131
}
3232

3333
fn run() -> Result<(), Box<Error>> {
34-
let export_dir = "examples/saved-regression-model"; // y = w * x + b
34+
let export_dir = "examples/regression_savedmodel"; // y = w * x + b
3535
if !Path::new(export_dir).exists() {
3636
return Err(Box::new(Status::new_set(Code::NotFound,
3737
&format!("Run 'python regression_savedmodel.py' to generate \

examples/regression_savedmodel.py renamed to examples/regression_savedmodel/regression_savedmodel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
init = tf.variables_initializer(tf.global_variables(), name='init')
2020

21-
directory = 'examples/saved-regression-model'
21+
directory = 'examples/regression_savedmodel'
2222
builder = SavedModelBuilder(directory)
2323

2424
with tf.Session(graph=tf.get_default_graph()) as sess:
18.3 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)