@@ -4,6 +4,7 @@ use std::ffi::CString;
44use std:: marker;
55use std:: path:: Path ;
66use std:: ptr;
7+ use super :: { Buffer , BufferTrait } ;
78use super :: Code ;
89use super :: DataType ;
910use super :: Graph ;
@@ -16,6 +17,64 @@ use super::Status;
1617use super :: Tensor ;
1718use super :: TensorType ;
1819
20+ /// Aggregation type for a saved model bundle.
21+ #[ derive( Debug ) ]
22+ pub struct SavedModelBundle {
23+ /// The loaded session.
24+ pub session : Session ,
25+ /// A meta graph definition as raw protocol buffer.
26+ pub meta_graph_def : Vec < u8 > ,
27+ }
28+
29+ impl SavedModelBundle {
30+
31+ /// Loads a session from an exported model, creating a bundle
32+ pub fn load < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
33+ ( options : & SessionOptions ,
34+ tags : Tags ,
35+ graph : & mut Graph ,
36+ export_dir : P )
37+ -> Result < SavedModelBundle > {
38+ let mut status = Status :: new ( ) ;
39+
40+ let export_dir_cstr =
41+ export_dir. as_ref ( )
42+ . to_str ( )
43+ . and_then ( |s| CString :: new ( s. as_bytes ( ) ) . ok ( ) )
44+ . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ?;
45+
46+ let tags_cstr: Vec < _ > = tags. into_iter ( )
47+ . map ( |t| CString :: new ( t. as_ref ( ) ) )
48+ . collect :: < :: std:: result:: Result < _ , _ > > ( )
49+ . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ?;
50+ let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
51+
52+ // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
53+ let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
54+
55+ let inner = unsafe {
56+ tf:: TF_LoadSessionFromSavedModel ( options. inner ,
57+ ptr:: null ( ) ,
58+ export_dir_cstr. as_ptr ( ) ,
59+ tags_ptr. as_ptr ( ) ,
60+ tags_ptr. len ( ) as c_int ,
61+ graph. inner ( ) ,
62+ meta. inner_mut ( ) ,
63+ status. inner ( ) )
64+ } ;
65+ if inner. is_null ( ) {
66+ Err ( status)
67+ } else {
68+ let session = Session { inner : inner } ;
69+ Ok ( SavedModelBundle {
70+ session : session,
71+ meta_graph_def : Vec :: from ( meta. as_ref ( ) )
72+ } )
73+ }
74+ }
75+
76+ }
77+
1978/// Manages a single graph and execution.
2079#[ derive( Debug ) ]
2180pub struct Session {
@@ -43,16 +102,15 @@ impl Session {
43102 -> Result < Self > {
44103 let mut status = Status :: new ( ) ;
45104
46- let export_dir_cstr =
47- try!( export_dir. as_ref ( )
105+ let export_dir_cstr = export_dir. as_ref ( )
48106 . to_str ( )
49107 . and_then ( |s| CString :: new ( s. as_bytes ( ) ) . ok ( ) )
50- . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ) ;
108+ . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ? ;
51109
52- let tags_cstr: Vec < _ > = try! ( tags. into_iter ( )
53- . map ( |t| CString :: new ( t. as_ref ( ) ) )
54- . collect :: < :: std:: result:: Result < _ , _ > > ( )
55- . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
110+ let tags_cstr: Vec < _ > = tags. into_iter ( )
111+ . map ( |t| CString :: new ( t. as_ref ( ) ) )
112+ . collect :: < :: std:: result:: Result < _ , _ > > ( )
113+ . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ? ;
56114 // keeping tags_cstr to retain strings in memory
57115 let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
58116
@@ -326,4 +384,39 @@ mod tests {
326384 assert_eq ! ( output_tensor[ 0 ] , 4.0 ) ;
327385 assert_eq ! ( output_tensor[ 1 ] , 6.0 ) ;
328386 }
387+
388+ #[ test]
389+ fn test_savedmodelbundle ( ) {
390+ let mut graph = Graph :: new ( ) ;
391+ let bundle = SavedModelBundle :: load (
392+ & SessionOptions :: new ( ) ,
393+ & [ "train" , "serve" ] ,
394+ & mut graph,
395+ "test_resources/regression-model" ,
396+ ) . unwrap ( ) ;
397+
398+ let x_op = graph. operation_by_name_required ( "x" ) . unwrap ( ) ;
399+ let y_op = graph. operation_by_name_required ( "y" ) . unwrap ( ) ;
400+ let y_hat_op = graph. operation_by_name_required ( "y_hat" ) . unwrap ( ) ;
401+ let _train_op = graph. operation_by_name_required ( "train" ) . unwrap ( ) ;
402+
403+ let SavedModelBundle {
404+ mut session,
405+ meta_graph_def,
406+ } = bundle;
407+
408+ assert ! ( !meta_graph_def. is_empty( ) ) ;
409+
410+ let mut x = <Tensor < f32 > >:: new ( & [ 1 ] ) ;
411+ x[ 0 ] = 2.0 ;
412+ let mut y = <Tensor < f32 > >:: new ( & [ 1 ] ) ;
413+ y[ 0 ] = 4.0 ;
414+ let mut step = StepWithGraph :: new ( ) ;
415+ step. add_input ( & x_op, 0 , & x) ;
416+ step. add_input ( & y_op, 0 , & y) ;
417+ let output_token = step. request_output ( & y_hat_op, 0 ) ;
418+ session. run ( & mut step) . unwrap ( ) ;
419+ let output_tensor = step. take_output :: < f32 > ( output_token) . unwrap ( ) ;
420+ assert_eq ! ( output_tensor. len( ) , 1 ) ;
421+ }
329422}
0 commit comments