@@ -4,6 +4,7 @@ use std::ffi::CString;
4
4
use std:: marker;
5
5
use std:: path:: Path ;
6
6
use std:: ptr;
7
+ use super :: { Buffer , BufferTrait } ;
7
8
use super :: Code ;
8
9
use super :: DataType ;
9
10
use super :: Graph ;
@@ -16,6 +17,64 @@ use super::Status;
16
17
use super :: Tensor ;
17
18
use super :: TensorType ;
18
19
20
+ /// Aggregation type for a saved model bundle.
21
+ #[ derive( Debug ) ]
22
+ pub struct SavedModelBundle {
23
+ /// The loaded session.
24
+ pub session : Session ,
25
+ /// A meta graph definition as raw protocol buffer.
26
+ pub meta_graph_def : Vec < u8 > ,
27
+ }
28
+
29
+ impl SavedModelBundle {
30
+
31
+ /// Loads a session from an exported model, creating a bundle
32
+ pub fn load < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
33
+ ( options : & SessionOptions ,
34
+ tags : Tags ,
35
+ graph : & mut Graph ,
36
+ export_dir : P )
37
+ -> Result < SavedModelBundle > {
38
+ let mut status = Status :: new ( ) ;
39
+
40
+ let export_dir_cstr =
41
+ export_dir. as_ref ( )
42
+ . to_str ( )
43
+ . and_then ( |s| CString :: new ( s. as_bytes ( ) ) . ok ( ) )
44
+ . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ?;
45
+
46
+ let tags_cstr: Vec < _ > = tags. into_iter ( )
47
+ . map ( |t| CString :: new ( t. as_ref ( ) ) )
48
+ . collect :: < :: std:: result:: Result < _ , _ > > ( )
49
+ . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ?;
50
+ let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
51
+
52
+ // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
53
+ let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
54
+
55
+ let inner = unsafe {
56
+ tf:: TF_LoadSessionFromSavedModel ( options. inner ,
57
+ ptr:: null ( ) ,
58
+ export_dir_cstr. as_ptr ( ) ,
59
+ tags_ptr. as_ptr ( ) ,
60
+ tags_ptr. len ( ) as c_int ,
61
+ graph. inner ( ) ,
62
+ meta. inner_mut ( ) ,
63
+ status. inner ( ) )
64
+ } ;
65
+ if inner. is_null ( ) {
66
+ Err ( status)
67
+ } else {
68
+ let session = Session { inner : inner } ;
69
+ Ok ( SavedModelBundle {
70
+ session : session,
71
+ meta_graph_def : Vec :: from ( meta. as_ref ( ) )
72
+ } )
73
+ }
74
+ }
75
+
76
+ }
77
+
19
78
/// Manages a single graph and execution.
20
79
#[ derive( Debug ) ]
21
80
pub struct Session {
@@ -43,16 +102,15 @@ impl Session {
43
102
-> Result < Self > {
44
103
let mut status = Status :: new ( ) ;
45
104
46
- let export_dir_cstr =
47
- try!( export_dir. as_ref ( )
105
+ let export_dir_cstr = export_dir. as_ref ( )
48
106
. to_str ( )
49
107
. and_then ( |s| CString :: new ( s. as_bytes ( ) ) . ok ( ) )
50
- . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ) ;
108
+ . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ? ;
51
109
52
- let tags_cstr: Vec < _ > = try! ( tags. into_iter ( )
53
- . map ( |t| CString :: new ( t. as_ref ( ) ) )
54
- . collect :: < :: std:: result:: Result < _ , _ > > ( )
55
- . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
110
+ let tags_cstr: Vec < _ > = tags. into_iter ( )
111
+ . map ( |t| CString :: new ( t. as_ref ( ) ) )
112
+ . collect :: < :: std:: result:: Result < _ , _ > > ( )
113
+ . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ? ;
56
114
// keeping tags_cstr to retain strings in memory
57
115
let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
58
116
@@ -326,4 +384,39 @@ mod tests {
326
384
assert_eq ! ( output_tensor[ 0 ] , 4.0 ) ;
327
385
assert_eq ! ( output_tensor[ 1 ] , 6.0 ) ;
328
386
}
387
+
388
+ #[ test]
389
+ fn test_savedmodelbundle ( ) {
390
+ let mut graph = Graph :: new ( ) ;
391
+ let bundle = SavedModelBundle :: load (
392
+ & SessionOptions :: new ( ) ,
393
+ & [ "train" , "serve" ] ,
394
+ & mut graph,
395
+ "test_resources/regression-model" ,
396
+ ) . unwrap ( ) ;
397
+
398
+ let x_op = graph. operation_by_name_required ( "x" ) . unwrap ( ) ;
399
+ let y_op = graph. operation_by_name_required ( "y" ) . unwrap ( ) ;
400
+ let y_hat_op = graph. operation_by_name_required ( "y_hat" ) . unwrap ( ) ;
401
+ let _train_op = graph. operation_by_name_required ( "train" ) . unwrap ( ) ;
402
+
403
+ let SavedModelBundle {
404
+ mut session,
405
+ meta_graph_def,
406
+ } = bundle;
407
+
408
+ assert ! ( !meta_graph_def. is_empty( ) ) ;
409
+
410
+ let mut x = <Tensor < f32 > >:: new ( & [ 1 ] ) ;
411
+ x[ 0 ] = 2.0 ;
412
+ let mut y = <Tensor < f32 > >:: new ( & [ 1 ] ) ;
413
+ y[ 0 ] = 4.0 ;
414
+ let mut step = StepWithGraph :: new ( ) ;
415
+ step. add_input ( & x_op, 0 , & x) ;
416
+ step. add_input ( & y_op, 0 , & y) ;
417
+ let output_token = step. request_output ( & y_hat_op, 0 ) ;
418
+ session. run ( & mut step) . unwrap ( ) ;
419
+ let output_tensor = step. take_output :: < f32 > ( output_token) . unwrap ( ) ;
420
+ assert_eq ! ( output_tensor. len( ) , 1 ) ;
421
+ }
329
422
}
0 commit comments