From dd2e334ceaf978bd5ce009357edfeb1061a87b71 Mon Sep 17 00:00:00 2001 From: Eduardo Pinho Date: Tue, 27 Jun 2017 17:11:47 +0100 Subject: [PATCH 1/3] Add Session::from_saved_model_to_bundle --- src/session.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/src/session.rs b/src/session.rs index e04f55c268..945c278a07 100644 --- a/src/session.rs +++ b/src/session.rs @@ -4,6 +4,7 @@ use std::ffi::CString; use std::marker; use std::path::Path; use std::ptr; +use super::{Buffer, BufferTrait}; use super::Code; use super::DataType; use super::Graph; @@ -16,6 +17,15 @@ use super::Status; use super::Tensor; use super::TensorType; +/// Aggregation type for a saved model bundle. +#[derive(Debug)] +pub struct SavedModelBundle { + /// The loaded session. + pub session: Session, + /// A meta graph defition as raw protocol buffer. + pub meta_graph_def: Vec, +} + /// Manages a single graph and execution. #[derive(Debug)] pub struct Session { @@ -73,6 +83,51 @@ impl Session { } } + /// Loads a session from an exported model, creating a bundle + pub fn from_saved_model_to_bundle, Tag: AsRef, Tags: IntoIterator> + (options: &SessionOptions, + tags: Tags, + graph: &mut Graph, + export_dir: P) + -> Result { + let mut status = Status::new(); + + let export_dir_cstr = + try!(export_dir.as_ref() + .to_str() + .and_then(|s| CString::new(s.as_bytes()).ok()) + .ok_or_else(|| invalid_arg!("Invalid export directory path"))); + + let tags_cstr: Vec<_> = try!(tags.into_iter() + .map(|t| CString::new(t.as_ref())) + .collect::<::std::result::Result<_, _>>() + .map_err(|_| invalid_arg!("Invalid tag name"))); + let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect(); + + // The empty TF_Buffer will be filled by LoadSessionFromSavedModel + let mut meta = unsafe { Buffer::::from_ptr(ptr::null_mut(), 0) }; + + let inner = unsafe { + tf::TF_LoadSessionFromSavedModel(options.inner, + ptr::null(), + export_dir_cstr.as_ptr(), + tags_ptr.as_ptr(), + tags_ptr.len() as c_int, + graph.inner(), + meta.inner_mut(), + status.inner()) + }; + if inner.is_null() { + Err(status) + } else { + let session = Session { inner: inner }; + Ok(SavedModelBundle { + session: session, + meta_graph_def: Vec::from(meta.as_ref()) + }) + } + } + /// Closes the session. pub fn close(&mut self) -> Result<()> { let mut status = Status::new(); From 43f6bf92d58fa92aca92de480313cd238fa64b6f Mon Sep 17 00:00:00 2001 From: Eduardo Pinho Date: Tue, 4 Jul 2017 10:21:31 +0100 Subject: [PATCH 2/3] move from_saved_model_to_bundle to SavedModelBundle::load --- src/session.rs | 72 ++++++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/src/session.rs b/src/session.rs index 945c278a07..b675e94370 100644 --- a/src/session.rs +++ b/src/session.rs @@ -26,31 +26,15 @@ pub struct SavedModelBundle { pub meta_graph_def: Vec, } -/// Manages a single graph and execution. -#[derive(Debug)] -pub struct Session { - inner: *mut tf::TF_Session, -} - -impl Session { - /// Creates a session. - pub fn new(options: &SessionOptions, graph: &Graph) -> Result { - let mut status = Status::new(); - let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) }; - if inner.is_null() { - Err(status) - } else { - Ok(Session { inner: inner }) - } - } +impl SavedModelBundle { - /// Loads a session from an exported model. - pub fn from_saved_model, Tag: AsRef, Tags: IntoIterator> + /// Loads a session from an exported model, creating a bundle + pub fn load, Tag: AsRef, Tags: IntoIterator> (options: &SessionOptions, tags: Tags, graph: &mut Graph, export_dir: P) - -> Result { + -> Result { let mut status = Status::new(); let export_dir_cstr = @@ -63,9 +47,11 @@ impl Session { .map(|t| CString::new(t.as_ref())) .collect::<::std::result::Result<_, _>>() .map_err(|_| invalid_arg!("Invalid tag name"))); - // keeping tags_cstr to retain strings in memory let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect(); + // The empty TF_Buffer will be filled by LoadSessionFromSavedModel + let mut meta = unsafe { Buffer::::from_ptr(ptr::null_mut(), 0) }; + let inner = unsafe { tf::TF_LoadSessionFromSavedModel(options.inner, ptr::null(), @@ -73,9 +59,33 @@ impl Session { tags_ptr.as_ptr(), tags_ptr.len() as c_int, graph.inner(), - ptr::null_mut(), + meta.inner_mut(), status.inner()) }; + if inner.is_null() { + Err(status) + } else { + let session = Session { inner: inner }; + Ok(SavedModelBundle { + session: session, + meta_graph_def: Vec::from(meta.as_ref()) + }) + } + } + +} + +/// Manages a single graph and execution. +#[derive(Debug)] +pub struct Session { + inner: *mut tf::TF_Session, +} + +impl Session { + /// Creates a session. + pub fn new(options: &SessionOptions, graph: &Graph) -> Result { + let mut status = Status::new(); + let inner = unsafe { tf::TF_NewSession(graph.inner(), options.inner, status.inner()) }; if inner.is_null() { Err(status) } else { @@ -83,13 +93,13 @@ impl Session { } } - /// Loads a session from an exported model, creating a bundle - pub fn from_saved_model_to_bundle, Tag: AsRef, Tags: IntoIterator> + /// Loads a session from an exported model. + pub fn from_saved_model, Tag: AsRef, Tags: IntoIterator> (options: &SessionOptions, tags: Tags, graph: &mut Graph, export_dir: P) - -> Result { + -> Result { let mut status = Status::new(); let export_dir_cstr = @@ -102,11 +112,9 @@ impl Session { .map(|t| CString::new(t.as_ref())) .collect::<::std::result::Result<_, _>>() .map_err(|_| invalid_arg!("Invalid tag name"))); + // keeping tags_cstr to retain strings in memory let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect(); - // The empty TF_Buffer will be filled by LoadSessionFromSavedModel - let mut meta = unsafe { Buffer::::from_ptr(ptr::null_mut(), 0) }; - let inner = unsafe { tf::TF_LoadSessionFromSavedModel(options.inner, ptr::null(), @@ -114,17 +122,13 @@ impl Session { tags_ptr.as_ptr(), tags_ptr.len() as c_int, graph.inner(), - meta.inner_mut(), + ptr::null_mut(), status.inner()) }; if inner.is_null() { Err(status) } else { - let session = Session { inner: inner }; - Ok(SavedModelBundle { - session: session, - meta_graph_def: Vec::from(meta.as_ref()) - }) + Ok(Session { inner: inner }) } } From 55a41e12d2a5b360c7aa95b9324659aa2d12b6ca Mon Sep 17 00:00:00 2001 From: Eduardo Pinho Date: Thu, 13 Jul 2017 00:27:47 +0100 Subject: [PATCH 3/3] Improve SavedModelBundle support - Give explicit name to an operation in regression_savedmodel script - Add test_resources with a saved model - Add test for saved model bundle - Fix typo - Wipe try! macros --- examples/regression_savedmodel.py | 2 +- src/session.rs | 62 ++++++++++++++---- .../regression-model/saved_model.pb | Bin 0 -> 17983 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 8 bytes .../variables/variables.index | Bin 0 -> 142 bytes 5 files changed, 49 insertions(+), 15 deletions(-) create mode 100644 test_resources/regression-model/saved_model.pb create mode 100644 test_resources/regression-model/variables/variables.data-00000-of-00001 create mode 100644 test_resources/regression-model/variables/variables.index diff --git a/examples/regression_savedmodel.py b/examples/regression_savedmodel.py index 2cd01147c8..d4879ee15c 100644 --- a/examples/regression_savedmodel.py +++ b/examples/regression_savedmodel.py @@ -10,7 +10,7 @@ w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w') b = tf.Variable(tf.zeros([1]), name='b') -y_hat = w * x + b +y_hat = tf.add(w * x, b, name="y_hat") loss = tf.reduce_mean(tf.square(y_hat - y)) optimizer = tf.train.GradientDescentOptimizer(0.5) diff --git a/src/session.rs b/src/session.rs index b675e94370..7197cd32ac 100644 --- a/src/session.rs +++ b/src/session.rs @@ -22,7 +22,7 @@ use super::TensorType; pub struct SavedModelBundle { /// The loaded session. pub session: Session, - /// A meta graph defition as raw protocol buffer. + /// A meta graph definition as raw protocol buffer. pub meta_graph_def: Vec, } @@ -38,15 +38,15 @@ impl SavedModelBundle { let mut status = Status::new(); let export_dir_cstr = - try!(export_dir.as_ref() + export_dir.as_ref() .to_str() .and_then(|s| CString::new(s.as_bytes()).ok()) - .ok_or_else(|| invalid_arg!("Invalid export directory path"))); + .ok_or_else(|| invalid_arg!("Invalid export directory path"))?; - let tags_cstr: Vec<_> = try!(tags.into_iter() - .map(|t| CString::new(t.as_ref())) - .collect::<::std::result::Result<_, _>>() - .map_err(|_| invalid_arg!("Invalid tag name"))); + let tags_cstr: Vec<_> = tags.into_iter() + .map(|t| CString::new(t.as_ref())) + .collect::<::std::result::Result<_, _>>() + .map_err(|_| invalid_arg!("Invalid tag name"))?; let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect(); // The empty TF_Buffer will be filled by LoadSessionFromSavedModel @@ -102,16 +102,15 @@ impl Session { -> Result { let mut status = Status::new(); - let export_dir_cstr = - try!(export_dir.as_ref() + let export_dir_cstr = export_dir.as_ref() .to_str() .and_then(|s| CString::new(s.as_bytes()).ok()) - .ok_or_else(|| invalid_arg!("Invalid export directory path"))); + .ok_or_else(|| invalid_arg!("Invalid export directory path"))?; - let tags_cstr: Vec<_> = try!(tags.into_iter() - .map(|t| CString::new(t.as_ref())) - .collect::<::std::result::Result<_, _>>() - .map_err(|_| invalid_arg!("Invalid tag name"))); + let tags_cstr: Vec<_> = tags.into_iter() + .map(|t| CString::new(t.as_ref())) + .collect::<::std::result::Result<_, _>>() + .map_err(|_| invalid_arg!("Invalid tag name"))?; // keeping tags_cstr to retain strings in memory let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect(); @@ -385,4 +384,39 @@ mod tests { assert_eq!(output_tensor[0], 4.0); assert_eq!(output_tensor[1], 6.0); } + + #[test] + fn test_savedmodelbundle() { + let mut graph = Graph::new(); + let bundle = SavedModelBundle::load( + &SessionOptions::new(), + &["train", "serve"], + &mut graph, + "test_resources/regression-model", + ).unwrap(); + + let x_op = graph.operation_by_name_required("x").unwrap(); + let y_op = graph.operation_by_name_required("y").unwrap(); + let y_hat_op = graph.operation_by_name_required("y_hat").unwrap(); + let _train_op = graph.operation_by_name_required("train").unwrap(); + + let SavedModelBundle { + mut session, + meta_graph_def, + } = bundle; + + assert!(!meta_graph_def.is_empty()); + + let mut x = >::new(&[1]); + x[0] = 2.0; + let mut y = >::new(&[1]); + y[0] = 4.0; + let mut step = StepWithGraph::new(); + step.add_input(&x_op, 0, &x); + step.add_input(&y_op, 0, &y); + let output_token = step.request_output(&y_hat_op, 0); + session.run(&mut step).unwrap(); + let output_tensor = step.take_output::(output_token).unwrap(); + assert_eq!(output_tensor.len(), 1); + } } diff --git a/test_resources/regression-model/saved_model.pb b/test_resources/regression-model/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..2cd4fc4d927635d0b590d726ce640e8d50f98307 GIT binary patch literal 17983 zcmc&++mBmE8Taws-(GWa+e~ha(~{IB_3kEVQmfEp6T(GG8oHqrTE!gOXLsFrZF7#D z-3==0LxsvksJtLl3Go6FZ@lyY5E6d?0ttBJkAZJy&dixPo-^xBBbtZUK9~9Kzwi6z zvSr8p)6X3E^tk)sad;JSH#;4-2+nTRc|bq+@RQ)@6a1_az>E3O-q3F>z})K0VJDj} z7IKxcTY?W9IC671?CpIu@H$<;Kf3LQZTxU&Aoth{@bwQI#|0>Oz2TNeN5}6Szc-@a z2|hF!kN88?(=izr7FXw;Y$caJR4A6+nc2C+)#ETT4*h0t(0|GZZuTk7z~6{zKK6QJpVC31NM|@6HRb(kMUq1S=kd{Q#~b<0aLXI|a&#w+tn(|! zfj8mkS}^cBZ7&=N`ELfBA?hU?E+b*~WrcEpR|0y4eW6#XM>L3I|`C2kG+i(ciw@Xf-lK|dUk1t@Z9tA!4=T6F;m zOcic%o9kFukUSl+v}t5sJ-RRL#f;EblZNol&d-STqK0bp(ev$ zXtmHd6JYQ9m1;%YuuakB6-l)v6V$L)^xD5KpzdW((G z29)l4yWQ>ab}FAoVmcKltrjqDa+QzKZ+_xH{;u!!RYL9wI28DuaeLGq^qZ0PtN91r z&aSSkqfmL|`@<$xs`S&$e$lf8Y$$GU&EZ<$-lf)jYxUNa-+nY4pe;l9I2>-vpUokL zO?MZ)q2wL|x8sMSZr@{!@;ghV>Pc9@u=X*62E7h47sSjt@EYXqj(e$!G6D=YC{@=b z$ldcd#oH8?C_Gh<#p<1b%4&(Jx9kE&(!IeOL%0k1`(FDII&+b^U5GD5?i;&an)_nR zEc2k4uMi+ygEe;?-F z^!lB__P6@ojX|(YnPNOszc+wLKJ@*LR7iCZpT-NT@s(p5tTakcSk0kTs)u4rMKJdf zn)Q27K=PZui<%2ZUNGX`PJi~g+g;@+&lfxv(3+9;*l)z{q+_#~rlAHeLLMnRQk6|W zyuXxFapy}=deirMPYXNm(AT-W5!oV?Mf^wPX& zNFf}$P+a#OKP4&mK&0+NVV$L0B^oeS3j>$Y&`p-W6h)l^dD)<(nT~^0JmDq1udq=`4mjjb!_hIM#C0>ttLdltS8Zesjx3<^1V^2I!48yQW~+EKy6>>89EH&e@J zG9UcaG_bcT1~$I)x<2l)39cD&__+h85-ejjQqrxSeTh+jHl0v5cm~G$MCYkQGcL9j zH2|KqX$6CoUV@(x;ecWSN#5A@b|3vAUUSPE z@yMg2kcc@;QWCOb%bt@s#K$4d-ZW&d#}yT?tOy-w8j@F|#(=d4B&y5y2e-R0gHW^S z*IBzF{5K5~6~E5aT;aQwwG!sPOvA}wgBlIT`WrFa;{ zN{^-nI=jHq#Q6zaKLiVnId`V~U-8d@ccNJfHg_SkMmW(aw2nBa%v|8ioyq~c3k#dk z6k=FkXVnUC4hSO``@_7_0MHCK73&Yw(b+7SP`9GObUmBP0943)%rl$pVpCouDT#W4 z5s#%w52`IV5f4shi<opYL{eFrPR2+hIrz{q)7!>L7(UgAQ-vgrrDw)Q04GgSGe>W(xRI+H*+5RM z`5|FWX%XDv}l=5XMrjyYr%46?Z!*@r*rLet6nF9*(<1lHFS^d$x}q%fPBgAt>^ z{0<8tvwf*LVp$Lp?z!2e>~umk;j~GnXwJXBC-TQrr|VxjW=IgL3SxPI?5|s5XLTxp zrJh*dY7%{euFG_AT}H0Z*bHLyzeu(NGK9+|#?i!#tC)HviEKQD=7iLarGA-}Q?poZ zz{e?}$EKy&0KCMrRlCBfshJRbv!2<| zv|iveG#PEmFi*c_;1dUk$swuaNM>IRTsEsGV$)jj@2`WD;vxc(=87GJHQl82Xz(M) z3iGG}3BwjWnQI7z&(wlPEmMt2S%5Bb62oq)%UTmbb7g;(56$7gb|kQU3Fguv=8{%R zbkhRsDpRyhpJDqJ==}5hW=9_dW1r?4u_|Xy@Dy~-cPPgeOqY1}LzO(AMc{oFPGZ!> zG%z{!&=e#QkH(4R%#Bo=D~{=4oGZOmfQM@2K84oVp{Sva|rY?-DRT(SnA zI&d-e2T8`f70A3pNOf8{0Iy^foYC)K{H{(+(>w=f64Nw^pMvci{MLbE2Da(a9$nv4 zkvw7I<_eiw;Hfa*Qf|RVT1t{Cm?-vAWJ%X-O)KQ4aogW()q!LMl|^G_U!Bzq+BP~# zV`U0%tcr#28f{Cr&CtCwW~ERugnQU5g8KxOs#U5XvIwd|7T_kEQ1UjP7Eo+afo3?^ zg!2Z)j>fqDTi*-@W850$4{`Sq-2_BVD#)GWoSe4YY_7oH9C#^}QXAKJgF%nd^82`P zti7i!keK9+)J2{*5lwPGg;{Esat3L*IBf9WU+(>1(%_&bxgJz8ve5~f5 zQlixUrB;`Mya4CjbNdLQ_$+Lwc+Zv~44H+lN5uM^3hER20qH0jNYbI1g_q;zD*FJ7 z+<$gJ0S$$a5jNbE+@g4&G6c^qRyjUfe)*$@mMUeDVGa^az#{o>Ve7k>GR~|@{O3NktI0l(kvK5&%F2D=!`9#)GxRO3+ zT}2TOZofNaJn+69dxRzhp2>T=`od6 zYxz=5+m&2myOOhOPHWBC!IXQ7OUcQg>W02@GE0y;F8p$SZ5I8K?rYAjW!5s5^rZ6u zKfpcYG@jNOAvZd}MT|NYjyEYfv2shC7Tx&0C=ZUYqkMRd?-#hkj!|#hrbNxW%sy23 z7Jyc&U68%qX*lVgV7(oGQ1`}E1OFOp5cb04(bl@*Tq`bg<4MspeQ_QqrL)+_5i#q~ z#ZGpRl%COu@=OU=J-Nyz*>er2amhhsQ9%ZmCMQni(oOwOV2)`=o_LO@dNqhjlbkH^ zJCp3EStJ%6cQ!EESf=nJjmGdx;5^&wk5Oe+xtb_&qsrL?nY8drk92M3wxqA~@;G(Y zGIU4KyiUrE&ugmIAU46te(Yje5Jxr(kXyyGQ>lF|;ED|ggyb2L7D=#bRjB!ph_Rv& z_BEa%*@V-&f^hGK33JUx-_h~bVlH?O&P@>m>~w1Y=WT&Ajk$2>oC8&!x=3gFC=c@8Xsbv*me zITmkv_!lMk%HCoAq`1NAl4`bmjJyW3j2J8W8p|#(Aq3PlmdV73&lN5@S2H(G-MOBv zeg!J{cT3pe&(OW@z5*v8w}Yd|16c&W(eqG^WqPYcZ;85=;5U8?|GGz+70b=R5YHUr zNqF|F`bN}u(X*w^XF#X+7zmFf+=8=Hr-w(enL{QvQ7@tUHY~DZqLABdEH65Q{s{uS df(Je!ww mL2y1T+`z!dRKZ{g7YcsIA8W_q%?A=71j2tebgPuQ-v$6Uju5s0 literal 0 HcmV?d00001