-
Notifications
You must be signed in to change notification settings - Fork 340
Adding File naming capability to from_saved_model and from_keras_model. #375
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
+12
−8
Merged
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
f0aedda
Firebase ML Kit Modify Operation Handling to not require a name for D…
ifielker 8939ed6
fixed lint
ifielker 7078cb6
Adding support for TensorFlow 2.x
ifielker d88a66b
fix typo
ifielker 8b5a6b3
remove extraneous @type from operations
ifielker 7ee369d
send updateMask in query parameter
ifielker d1a9933
send list filters etc in query parameters
ifielker 68ebe77
fix typo
ifielker 5633b9f
fix typo
ifielker d715628
fix typo
ifielker b63a4fe
urlEncode filter string
ifielker 71e6705
adding optional file names to conversion functions (start)
ifielker 6dc6d0d
merged
ifielker d24e742
adding File naming capability for ModelSource
ifielker 0233bfa
fixed lint
ifielker 7230667
fixed file descriptor leak
ifielker File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -524,11 +524,13 @@ def _tf_convert_from_keras_model(keras_model): | |
return converter.convert() | ||
|
||
@classmethod | ||
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): | ||
def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', | ||
bucket_name=None, app=None): | ||
"""Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. | ||
|
||
Args: | ||
saved_model_dir: The saved model directory. | ||
model_file_name: The name that the tflite model will be saved as in Cloud Storage. | ||
bucket_name: The name of an existing bucket. None to use the default bucket configured | ||
in the app. | ||
app: Optional. A Firebase app instance (or None to use the default app) | ||
|
@@ -541,16 +543,17 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): | |
""" | ||
TFLiteGCSModelSource._assert_tf_enabled() | ||
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) | ||
open('firebase_ml_model.tflite', 'wb').write(tflite_model) | ||
return TFLiteGCSModelSource.from_tflite_model_file( | ||
'firebase_ml_model.tflite', bucket_name, app) | ||
open(model_file_name, 'wb').write(tflite_model) | ||
return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) | ||
|
||
@classmethod | ||
def from_keras_model(cls, keras_model, bucket_name=None, app=None): | ||
def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', | ||
bucket_name=None, app=None): | ||
"""Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. | ||
|
||
Args: | ||
keras_model: A tf.keras model. | ||
model_file_name: The name that the tflite model will be saved as in Cloud Storage. | ||
bucket_name: The name of an existing bucket. None to use the default bucket configured | ||
in the app. | ||
app: Optional. A Firebase app instance (or None to use the default app) | ||
|
@@ -563,9 +566,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): | |
""" | ||
TFLiteGCSModelSource._assert_tf_enabled() | ||
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) | ||
open('firebase_ml_model.tflite', 'wb').write(tflite_model) | ||
return TFLiteGCSModelSource.from_tflite_model_file( | ||
'firebase_ml_model.tflite', bucket_name, app) | ||
open(model_file_name, 'wb').write(tflite_model) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too. |
||
return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) | ||
|
||
@property | ||
def gcs_tflite_uri(self): | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed this earlier. But this is leaking a file descriptor. Can you change this to: