@@ -26,31 +26,15 @@ pub struct SavedModelBundle {
26
26
pub meta_graph_def : Vec < u8 > ,
27
27
}
28
28
29
- /// Manages a single graph and execution.
30
- #[ derive( Debug ) ]
31
- pub struct Session {
32
- inner : * mut tf:: TF_Session ,
33
- }
34
-
35
- impl Session {
36
- /// Creates a session.
37
- pub fn new ( options : & SessionOptions , graph : & Graph ) -> Result < Self > {
38
- let mut status = Status :: new ( ) ;
39
- let inner = unsafe { tf:: TF_NewSession ( graph. inner ( ) , options. inner , status. inner ( ) ) } ;
40
- if inner. is_null ( ) {
41
- Err ( status)
42
- } else {
43
- Ok ( Session { inner : inner } )
44
- }
45
- }
29
+ impl SavedModelBundle {
46
30
47
- /// Loads a session from an exported model.
48
- pub fn from_saved_model < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
31
+ /// Loads a session from an exported model, creating a bundle
32
+ pub fn load < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
49
33
( options : & SessionOptions ,
50
34
tags : Tags ,
51
35
graph : & mut Graph ,
52
36
export_dir : P )
53
- -> Result < Self > {
37
+ -> Result < SavedModelBundle > {
54
38
let mut status = Status :: new ( ) ;
55
39
56
40
let export_dir_cstr =
@@ -63,33 +47,59 @@ impl Session {
63
47
. map ( |t| CString :: new ( t. as_ref ( ) ) )
64
48
. collect :: < :: std:: result:: Result < _ , _ > > ( )
65
49
. map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
66
- // keeping tags_cstr to retain strings in memory
67
50
let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
68
51
52
+ // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
53
+ let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
54
+
69
55
let inner = unsafe {
70
56
tf:: TF_LoadSessionFromSavedModel ( options. inner ,
71
57
ptr:: null ( ) ,
72
58
export_dir_cstr. as_ptr ( ) ,
73
59
tags_ptr. as_ptr ( ) ,
74
60
tags_ptr. len ( ) as c_int ,
75
61
graph. inner ( ) ,
76
- ptr :: null_mut ( ) ,
62
+ meta . inner_mut ( ) ,
77
63
status. inner ( ) )
78
64
} ;
65
+ if inner. is_null ( ) {
66
+ Err ( status)
67
+ } else {
68
+ let session = Session { inner : inner } ;
69
+ Ok ( SavedModelBundle {
70
+ session : session,
71
+ meta_graph_def : Vec :: from ( meta. as_ref ( ) )
72
+ } )
73
+ }
74
+ }
75
+
76
+ }
77
+
78
+ /// Manages a single graph and execution.
79
+ #[ derive( Debug ) ]
80
+ pub struct Session {
81
+ inner : * mut tf:: TF_Session ,
82
+ }
83
+
84
+ impl Session {
85
+ /// Creates a session.
86
+ pub fn new ( options : & SessionOptions , graph : & Graph ) -> Result < Self > {
87
+ let mut status = Status :: new ( ) ;
88
+ let inner = unsafe { tf:: TF_NewSession ( graph. inner ( ) , options. inner , status. inner ( ) ) } ;
79
89
if inner. is_null ( ) {
80
90
Err ( status)
81
91
} else {
82
92
Ok ( Session { inner : inner } )
83
93
}
84
94
}
85
95
86
- /// Loads a session from an exported model, creating a bundle
87
- pub fn from_saved_model_to_bundle < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
96
+ /// Loads a session from an exported model.
97
+ pub fn from_saved_model < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
88
98
( options : & SessionOptions ,
89
99
tags : Tags ,
90
100
graph : & mut Graph ,
91
101
export_dir : P )
92
- -> Result < SavedModelBundle > {
102
+ -> Result < Self > {
93
103
let mut status = Status :: new ( ) ;
94
104
95
105
let export_dir_cstr =
@@ -102,29 +112,23 @@ impl Session {
102
112
. map ( |t| CString :: new ( t. as_ref ( ) ) )
103
113
. collect :: < :: std:: result:: Result < _ , _ > > ( )
104
114
. map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
115
+ // keeping tags_cstr to retain strings in memory
105
116
let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
106
117
107
- // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
108
- let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
109
-
110
118
let inner = unsafe {
111
119
tf:: TF_LoadSessionFromSavedModel ( options. inner ,
112
120
ptr:: null ( ) ,
113
121
export_dir_cstr. as_ptr ( ) ,
114
122
tags_ptr. as_ptr ( ) ,
115
123
tags_ptr. len ( ) as c_int ,
116
124
graph. inner ( ) ,
117
- meta . inner_mut ( ) ,
125
+ ptr :: null_mut ( ) ,
118
126
status. inner ( ) )
119
127
} ;
120
128
if inner. is_null ( ) {
121
129
Err ( status)
122
130
} else {
123
- let session = Session { inner : inner } ;
124
- Ok ( SavedModelBundle {
125
- session : session,
126
- meta_graph_def : Vec :: from ( meta. as_ref ( ) )
127
- } )
131
+ Ok ( Session { inner : inner } )
128
132
}
129
133
}
130
134
0 commit comments