@@ -26,31 +26,15 @@ pub struct SavedModelBundle {
2626 pub meta_graph_def : Vec < u8 > ,
2727}
2828
29- /// Manages a single graph and execution.
30- #[ derive( Debug ) ]
31- pub struct Session {
32- inner : * mut tf:: TF_Session ,
33- }
34-
35- impl Session {
36- /// Creates a session.
37- pub fn new ( options : & SessionOptions , graph : & Graph ) -> Result < Self > {
38- let mut status = Status :: new ( ) ;
39- let inner = unsafe { tf:: TF_NewSession ( graph. inner ( ) , options. inner , status. inner ( ) ) } ;
40- if inner. is_null ( ) {
41- Err ( status)
42- } else {
43- Ok ( Session { inner : inner } )
44- }
45- }
29+ impl SavedModelBundle {
4630
47- /// Loads a session from an exported model.
48- pub fn from_saved_model < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
31+ /// Loads a session from an exported model, creating a bundle
32+ pub fn load < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
4933 ( options : & SessionOptions ,
5034 tags : Tags ,
5135 graph : & mut Graph ,
5236 export_dir : P )
53- -> Result < Self > {
37+ -> Result < SavedModelBundle > {
5438 let mut status = Status :: new ( ) ;
5539
5640 let export_dir_cstr =
@@ -63,33 +47,59 @@ impl Session {
6347 . map ( |t| CString :: new ( t. as_ref ( ) ) )
6448 . collect :: < :: std:: result:: Result < _ , _ > > ( )
6549 . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
66- // keeping tags_cstr to retain strings in memory
6750 let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
6851
52+ // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
53+ let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
54+
6955 let inner = unsafe {
7056 tf:: TF_LoadSessionFromSavedModel ( options. inner ,
7157 ptr:: null ( ) ,
7258 export_dir_cstr. as_ptr ( ) ,
7359 tags_ptr. as_ptr ( ) ,
7460 tags_ptr. len ( ) as c_int ,
7561 graph. inner ( ) ,
76- ptr :: null_mut ( ) ,
62+ meta . inner_mut ( ) ,
7763 status. inner ( ) )
7864 } ;
65+ if inner. is_null ( ) {
66+ Err ( status)
67+ } else {
68+ let session = Session { inner : inner } ;
69+ Ok ( SavedModelBundle {
70+ session : session,
71+ meta_graph_def : Vec :: from ( meta. as_ref ( ) )
72+ } )
73+ }
74+ }
75+
76+ }
77+
78+ /// Manages a single graph and execution.
79+ #[ derive( Debug ) ]
80+ pub struct Session {
81+ inner : * mut tf:: TF_Session ,
82+ }
83+
84+ impl Session {
85+ /// Creates a session.
86+ pub fn new ( options : & SessionOptions , graph : & Graph ) -> Result < Self > {
87+ let mut status = Status :: new ( ) ;
88+ let inner = unsafe { tf:: TF_NewSession ( graph. inner ( ) , options. inner , status. inner ( ) ) } ;
7989 if inner. is_null ( ) {
8090 Err ( status)
8191 } else {
8292 Ok ( Session { inner : inner } )
8393 }
8494 }
8595
86- /// Loads a session from an exported model, creating a bundle
87- pub fn from_saved_model_to_bundle < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
96+ /// Loads a session from an exported model.
97+ pub fn from_saved_model < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
8898 ( options : & SessionOptions ,
8999 tags : Tags ,
90100 graph : & mut Graph ,
91101 export_dir : P )
92- -> Result < SavedModelBundle > {
102+ -> Result < Self > {
93103 let mut status = Status :: new ( ) ;
94104
95105 let export_dir_cstr =
@@ -102,29 +112,23 @@ impl Session {
102112 . map ( |t| CString :: new ( t. as_ref ( ) ) )
103113 . collect :: < :: std:: result:: Result < _ , _ > > ( )
104114 . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
115+ // keeping tags_cstr to retain strings in memory
105116 let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
106117
107- // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
108- let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
109-
110118 let inner = unsafe {
111119 tf:: TF_LoadSessionFromSavedModel ( options. inner ,
112120 ptr:: null ( ) ,
113121 export_dir_cstr. as_ptr ( ) ,
114122 tags_ptr. as_ptr ( ) ,
115123 tags_ptr. len ( ) as c_int ,
116124 graph. inner ( ) ,
117- meta . inner_mut ( ) ,
125+ ptr :: null_mut ( ) ,
118126 status. inner ( ) )
119127 } ;
120128 if inner. is_null ( ) {
121129 Err ( status)
122130 } else {
123- let session = Session { inner : inner } ;
124- Ok ( SavedModelBundle {
125- session : session,
126- meta_graph_def : Vec :: from ( meta. as_ref ( ) )
127- } )
131+ Ok ( Session { inner : inner } )
128132 }
129133 }
130134
0 commit comments