Skip to content

Commit 43f6bf9

Browse files
committed
move from_saved_model_to_bundle to SavedModelBundle::load
1 parent dd2e334 commit 43f6bf9

File tree

1 file changed

+38
-34
lines changed

1 file changed

+38
-34
lines changed

src/session.rs

+38-34
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,15 @@ pub struct SavedModelBundle {
2626
pub meta_graph_def: Vec<u8>,
2727
}
2828

29-
/// Manages a single graph and execution.
30-
#[derive(Debug)]
31-
pub struct Session {
32-
inner: *mut tf::TF_Session,
33-
}
34-
35-
impl Session {
36-
/// Creates a session.
37-
pub fn new(options: &SessionOptions, graph: &Graph) -> Result<Self> {
38-
let mut status = Status::new();
39-
let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) };
40-
if inner.is_null() {
41-
Err(status)
42-
} else {
43-
Ok(Session { inner: inner })
44-
}
45-
}
29+
impl SavedModelBundle {
4630

47-
/// Loads a session from an exported model.
48-
pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
31+
/// Loads a session from an exported model, creating a bundle
32+
pub fn load<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
4933
(options: &SessionOptions,
5034
tags: Tags,
5135
graph: &mut Graph,
5236
export_dir: P)
53-
-> Result<Self> {
37+
-> Result<SavedModelBundle> {
5438
let mut status = Status::new();
5539

5640
let export_dir_cstr =
@@ -63,33 +47,59 @@ impl Session {
6347
.map(|t| CString::new(t.as_ref()))
6448
.collect::<::std::result::Result<_, _>>()
6549
.map_err(|_| invalid_arg!("Invalid tag name")));
66-
// keeping tags_cstr to retain strings in memory
6750
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
6851

52+
// The empty TF_Buffer will be filled by LoadSessionFromSavedModel
53+
let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };
54+
6955
let inner = unsafe {
7056
tf::TF_LoadSessionFromSavedModel(options.inner,
7157
ptr::null(),
7258
export_dir_cstr.as_ptr(),
7359
tags_ptr.as_ptr(),
7460
tags_ptr.len() as c_int,
7561
graph.inner(),
76-
ptr::null_mut(),
62+
meta.inner_mut(),
7763
status.inner())
7864
};
65+
if inner.is_null() {
66+
Err(status)
67+
} else {
68+
let session = Session { inner: inner };
69+
Ok(SavedModelBundle {
70+
session: session,
71+
meta_graph_def: Vec::from(meta.as_ref())
72+
})
73+
}
74+
}
75+
76+
}
77+
78+
/// Manages a single graph and execution.
79+
#[derive(Debug)]
80+
pub struct Session {
81+
inner: *mut tf::TF_Session,
82+
}
83+
84+
impl Session {
85+
/// Creates a session.
86+
pub fn new(options: &SessionOptions, graph: &Graph) -> Result<Self> {
87+
let mut status = Status::new();
88+
let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) };
7989
if inner.is_null() {
8090
Err(status)
8191
} else {
8292
Ok(Session { inner: inner })
8393
}
8494
}
8595

86-
/// Loads a session from an exported model, creating a bundle
87-
pub fn from_saved_model_to_bundle<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
96+
/// Loads a session from an exported model.
97+
pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
8898
(options: &SessionOptions,
8999
tags: Tags,
90100
graph: &mut Graph,
91101
export_dir: P)
92-
-> Result<SavedModelBundle> {
102+
-> Result<Self> {
93103
let mut status = Status::new();
94104

95105
let export_dir_cstr =
@@ -102,29 +112,23 @@ impl Session {
102112
.map(|t| CString::new(t.as_ref()))
103113
.collect::<::std::result::Result<_, _>>()
104114
.map_err(|_| invalid_arg!("Invalid tag name")));
115+
// keeping tags_cstr to retain strings in memory
105116
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
106117

107-
// The empty TF_Buffer will be filled by LoadSessionFromSavedModel
108-
let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };
109-
110118
let inner = unsafe {
111119
tf::TF_LoadSessionFromSavedModel(options.inner,
112120
ptr::null(),
113121
export_dir_cstr.as_ptr(),
114122
tags_ptr.as_ptr(),
115123
tags_ptr.len() as c_int,
116124
graph.inner(),
117-
meta.inner_mut(),
125+
ptr::null_mut(),
118126
status.inner())
119127
};
120128
if inner.is_null() {
121129
Err(status)
122130
} else {
123-
let session = Session { inner: inner };
124-
Ok(SavedModelBundle {
125-
session: session,
126-
meta_graph_def: Vec::from(meta.as_ref())
127-
})
131+
Ok(Session { inner: inner })
128132
}
129133
}
130134

0 commit comments

Comments
 (0)