-
Notifications
You must be signed in to change notification settings - Fork 428
Load from saved model support #68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- add TF_LoadSessionFromSavedModel function to tensorflow-sys - add function in Session that provides safe loading from a saved model
src/session.rs
Outdated
use std::ptr; | ||
use std::result::Result as StdResult; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm I find this confusing, but if @adamcrume is good with it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, I just needed a hint on line 56 to pluck the result out, and this approach sounded elegant to me. We could consider something else though...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In other places we're naming std::result::Result
directly without importing it, like so:
pub type Result<T> = std::result::Result<T, Status>;
I'd prefer that, just for consistency. On a side note, I regret adding that type alias. I assumed it was an accepted pattern because of std::fmt::Result
, std::io::Result
, and std::thread::Result
, but having multiple types with the same name (just in different modules) causes no end of headaches.
src/session.rs
Outdated
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect(); | ||
|
||
let inner = unsafe { | ||
tf::TF_LoadSessionFromSavedModel( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
formatting looks a bit off here, can you run it through rustfmt please?
@Enet4 I'd love to see at least an example on this, since we are lacking them anyways and it also helps other people onboard more quickly |
@daschl All right, I have formatted the code (my bad!) and added an example based on regression, which was already available. Still, I'm open to file renames or other tweaks. |
@Enet4 very cool, thanks! Of course @adamcrume has the final say on this ;) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly looks good, with a few tweaks. Please run the code through rustfmt. Also, thanks for the example code; we can always use more examples and more tests.
src/session.rs
Outdated
use std::ptr; | ||
use std::result::Result as StdResult; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In other places we're naming std::result::Result
directly without importing it, like so:
pub type Result<T> = std::result::Result<T, Status>;
I'd prefer that, just for consistency. On a side note, I regret adding that type alias. I assumed it was an accepted pattern because of std::fmt::Result
, std::io::Result
, and std::thread::Result
, but having multiple types with the same name (just in different modules) causes no end of headaches.
src/session.rs
Outdated
.to_str() | ||
.and_then(|s| CString::new(s.as_bytes()).ok()) | ||
.ok_or_else(|| { | ||
Status::new_set(Code::InvalidArgument, "Invalid export directory path").unwrap() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the invalid_arg!
macro for this.
src/session.rs
Outdated
.map(|t| CString::new(t.as_ref())) | ||
.collect::<StdResult<_, _>>() | ||
.map_err(|_| { | ||
Status::new_set(Code::InvalidArgument, "Invalid tag name").unwrap() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
invalid_arg!
again.
src/session.rs
Outdated
let inner = unsafe { | ||
tf::TF_LoadSessionFromSavedModel(options.inner, | ||
ptr::null(), | ||
export_dir_cstr.to_bytes_with_nul().as_ptr() as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return value of to_bytes_with_nul
doesn't live long enough. You can just use export_dir_cstr.as_ptr()
, since it already guarantees a null terminator.
- use invalid_arg! - remove std result alias - fix getting char pointer to export_dir
I still wonder how I got the tag conversion right and then screwed up on the export_dir string... 🤔 Nevertheless, the changes were made. :) |
Thanks! |
Note that we're preferring to add a SavedModelBundle wrapper for the return value in other languages. You need the MetaGraphDef to extract out the signatures in the SavedModel. |
@jhseu Admittedly, I knew that at least the Java bindings would be doing it with a bundle (tensorflow/tensorflow#7134), but I had found no reason to replicate that design here, at the time. But given that the meta-graph is still unreachable, I agree that we should still seek to improve this saved model API. Also, a proper MetaGraph abstraction would probably be nicer than just retrieving a byte buffer. |
* added new row indexer for parquet data frame * updated all tests and code to use DateTimeOffset * added logical JSON type * added new dataset handling of rows through pivoting * Update PlainValuesReader.cs * built more single responsibility around ParquetReader type to ensure efficient deallocation of resources using IDisposable * updated reader to look at nulls * added branches to set type IList as either nullable or non-nullable and done this against the required attribute on the column header * moved BigDecimal to own file
This PR exposes TensorFlow's native capability of loading saved model bundles from a directory.
Session::from_saved_model
, which provides a safe generic API with the minimum arguments required. In the future, one might consider adding alternative functions that would let users specifyrun_options
and retrievemeta_graph_def
.This feature will hopefully make loading pre-trained models in Rust more accessible. Please let me know if you would like a complete example or an integration test. Some of the examples found in this repository should be easily adjusted to test this feature.
Example of use: