
Saving a SavedModel after loading a graph #279


Open
lucamc9 opened this issue Oct 15, 2020 · 5 comments

Comments

@lucamc9

lucamc9 commented Oct 15, 2020

Hi, is there a way to save a SavedModel after training when the graph was loaded from an existing SavedModel (as opposed to initialising the layers and variables in Rust, as in examples/xor.rs)?

It seems that SavedModelBuilder requires a collection of variables and a scope, which don't seem straightforward to obtain from an already existing graph.

@kotatsuyaki

kotatsuyaki commented Dec 19, 2020

I'm also interested in this, but have yet to find any workaround. While the signatures can be obtained from MetaGraphDef::signatures, there seems to be no way to get the collection of variables and the scope when a model is loaded via SavedModelBundle::load or Session::from_saved_model.

EDIT 2020/12/20: What I want to do is build the model in Python and then train it from Rust, so I'm looking for a way to keep all the trained parameters, using either a SavedModel or checkpoints. Things I've tried so far:

  • Make the saving operations part of the saved model itself
    • Include the saving operations (tf.saved_model.save() or model.save()) in a @tf.function on the Python side. This fails with errors like "save is not supported inside a traced @tf.function", so it doesn't work.
    • Include the Keras model.save_weights() in a @tf.function on the Python side. This fails with "RuntimeError: Cannot get session inside Tensorflow graph function."
  • Try to find a way to get the parameters needed by tensorflow::SavedModelBuilder from the Rust side, but this is beyond my knowledge of TensorFlow.
  • Try to use the tf.train.Saver class, as suggested in Possibility to Save/Restore Checkpoints #30, but it seems to be no longer available in TensorFlow 2.0 and onwards.

Another interesting finding is that in the C API, while there's the TF_CheckpointReader type (c_api_experimental.h) for reading checkpoints and TF_LoadSessionFromSavedModel (c_api.h) for loading saved models, there seem to be no functions for saving state.

@kotatsuyaki

I found a (rather hacky and ugly) workaround for this.

  • On the Python side, call tf.train.Checkpoint.write inside a @tf.function. Unlike tf.train.Checkpoint.save, write is implemented purely with TensorFlow ops, so it runs fine even in graph mode.
  • Save a concrete function of it as a signature together with the model, for example:
    tf.saved_model.save(my_model, modelpath, signatures={
      'ckpt_write': my_model.ckpt_write,
      # Other functions to be saved ...
    })
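  • Run that ckpt_write function from the Rust side, for example:
    // `bundle` is the loaded SavedModelBundle, `graph` the Graph it was
    // loaded into, and `tf` the `tensorflow` crate.
    let signature = &bundle.meta_graph_def().signatures()["ckpt_write"];
    // TensorInfo::name() returns a TensorId whose `name` field is the bare
    // operation name (without the ":0" output index).
    let ckpt_write_op =
        graph.operation_by_name_required(&signature.outputs()["output_0"].name().name)?;
    let mut args = tf::SessionRunArgs::new();
    // Adding the op as a target runs it, which writes the checkpoint.
    args.add_target(&ckpt_write_op);
    session.run(&mut args)?;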
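If you only have the raw tensor name as a string (for example, the `name:` field printed by `saved_model_cli show --all`), keep in mind that signature outputs are tensor names of the form `<operation>:<output index>`, whereas `Graph::operation_by_name` expects just the operation name. A minimal stdlib-only sketch of splitting such a name; `split_tensor_name` is a hypothetical helper, not part of the `tensorflow` crate, and the example name is only illustrative:

```rust
/// Hypothetical helper (not part of the `tensorflow` crate): split a tensor
/// name such as "StatefulPartitionedCall_1:0" into the operation name and the
/// output index. A name without a numeric ":<index>" suffix is treated as
/// output 0 of an operation with that name.
fn split_tensor_name(tensor_name: &str) -> (&str, i32) {
    if let Some((op, index)) = tensor_name.rsplit_once(':') {
        if let Ok(i) = index.parse::<i32>() {
            return (op, i);
        }
    }
    (tensor_name, 0)
}

fn main() {
    // Illustrative tensor name as it might appear in a signature's outputs.
    let (op, index) = split_tensor_name("StatefulPartitionedCall_1:0");
    // Prints: operation = StatefulPartitionedCall_1, output index = 0
    println!("operation = {op}, output index = {index}");
}
```

The operation name is the part you would look up with `graph.operation_by_name` before adding it as a target.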

This makes it possible to save the state of a network trained from the Rust side, but I still haven't found any way to restore from the checkpoint in Rust. Restoring can, however, be done as follows:

  • Read the checkpoint in Python with tf.train.Checkpoint.read.
  • Save the whole model again, so that Rust can reload it with SavedModelBundle::load.
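Because the restore step hands the checkpoint back to Python, it can help to confirm from Rust that running ckpt_write actually produced files at the expected prefix. A minimal stdlib-only sketch, assuming the TF2 checkpoint layout of a `<prefix>.index` file plus one or more `<prefix>.data-XXXXX-of-YYYYY` shards; `checkpoint_exists` and the `/tmp/my_model/ckpt` prefix are hypothetical, not part of the `tensorflow` crate or of the workaround above:

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Hypothetical helper (not part of the `tensorflow` crate): returns true if a
/// TF2-style checkpoint appears to exist at `prefix`, i.e. there is a
/// `<prefix>.index` file and at least one `<prefix>.data-*` shard beside it.
fn checkpoint_exists(prefix: &Path) -> io::Result<bool> {
    // The file-name part of the prefix, e.g. "ckpt" for "/tmp/my_model/ckpt".
    let stem = match prefix.file_name().and_then(|n| n.to_str()) {
        Some(s) => s,
        None => return Ok(false),
    };
    // Directory holding the checkpoint files (current dir for a bare prefix).
    let dir = match prefix.parent() {
        Some(d) if !d.as_os_str().is_empty() => d,
        _ => Path::new("."),
    };
    if !dir.join(format!("{stem}.index")).is_file() {
        return Ok(false);
    }
    // Look for at least one data shard, e.g. "ckpt.data-00000-of-00001".
    let data_prefix = format!("{stem}.data-");
    for entry in fs::read_dir(dir)? {
        if entry?.file_name().to_string_lossy().starts_with(&data_prefix) {
            return Ok(true);
        }
    }
    Ok(false)
}

fn main() -> io::Result<()> {
    // Hypothetical prefix: whatever path the Python-side ckpt_write passed to
    // tf.train.Checkpoint.write.
    let prefix = Path::new("/tmp/my_model/ckpt");
    if checkpoint_exists(prefix)? {
        println!("checkpoint found at {}", prefix.display());
    } else {
        println!("no checkpoint at {}", prefix.display());
    }
    Ok(())
}
```

This only checks that the files exist; it does not validate their contents, which is left to tf.train.Checkpoint.read on the Python side.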

@Trolldemorted

It has been 2 years; is there a cleaner way of doing this by now?

@AcrylicShrimp

// TODO: support all fields

Well, in 2023, the other fields still aren't supported.

@bitmagier

What a pity that this essential functionality is not yet covered properly.
