@@ -4,6 +4,7 @@ use std::ffi::CString;
44use std:: marker;
55use std:: path:: Path ;
66use std:: ptr;
7+ use super :: { Buffer , BufferTrait } ;
78use super :: Code ;
89use super :: DataType ;
910use super :: Graph ;
@@ -16,6 +17,15 @@ use super::Status;
1617use super :: Tensor ;
1718use super :: TensorType ;
1819
20+ /// Aggregation type for a saved model bundle.
21+ #[ derive( Debug ) ]
22+ pub struct SavedModelBundle {
23+ /// The loaded session.
24+ pub session : Session ,
25+ /// A meta graph defition as raw protocol buffer.
26+ pub meta_graph_def : Vec < u8 > ,
27+ }
28+
1929/// Manages a single graph and execution.
2030#[ derive( Debug ) ]
2131pub struct Session {
@@ -73,6 +83,51 @@ impl Session {
7383 }
7484 }
7585
86+ /// Loads a session from an exported model, creating a bundle
87+ pub fn from_saved_model_to_bundle < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
88+ ( options : & SessionOptions ,
89+ tags : Tags ,
90+ graph : & mut Graph ,
91+ export_dir : P )
92+ -> Result < SavedModelBundle > {
93+ let mut status = Status :: new ( ) ;
94+
95+ let export_dir_cstr =
96+ try!( export_dir. as_ref ( )
97+ . to_str ( )
98+ . and_then ( |s| CString :: new ( s. as_bytes ( ) ) . ok ( ) )
99+ . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ) ;
100+
101+ let tags_cstr: Vec < _ > = try!( tags. into_iter ( )
102+ . map ( |t| CString :: new ( t. as_ref ( ) ) )
103+ . collect :: < :: std:: result:: Result < _ , _ > > ( )
104+ . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
105+ let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
106+
107+ // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
108+ let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
109+
110+ let inner = unsafe {
111+ tf:: TF_LoadSessionFromSavedModel ( options. inner ,
112+ ptr:: null ( ) ,
113+ export_dir_cstr. as_ptr ( ) ,
114+ tags_ptr. as_ptr ( ) ,
115+ tags_ptr. len ( ) as c_int ,
116+ graph. inner ( ) ,
117+ meta. inner_mut ( ) ,
118+ status. inner ( ) )
119+ } ;
120+ if inner. is_null ( ) {
121+ Err ( status)
122+ } else {
123+ let session = Session { inner : inner } ;
124+ Ok ( SavedModelBundle {
125+ session : session,
126+ meta_graph_def : Vec :: from ( meta. as_ref ( ) )
127+ } )
128+ }
129+ }
130+
76131 /// Closes the session.
77132 pub fn close ( & mut self ) -> Result < ( ) > {
78133 let mut status = Status :: new ( ) ;
0 commit comments