99import base64
1010import copy
1111import dataclasses
12+ import io
1213import json
1314import logging
1415import operator
16+ import os
17+ import zipfile
1518from typing import Any , Callable , Dict , List , Optional , Union
1619
1720import executorch .exir as exir
3033from executorch .exir .lowered_backend_module import (
3134 LoweredBackendModule as ExirLoweredBackendModule ,
3235)
36+ from executorch .exir .serde .export_serialize import SerializedArtifact
3337from executorch .exir .serde .schema import (
3438 CompileSpec ,
3539 LoweredBackendModule as SerdeLoweredBackendModule ,
40+ SCHEMA_VERSION ,
3641)
3742from torch ._export .serde .schema import SchemaVersion
3843from torch ._export .serde .serialize import SerializeError
@@ -628,7 +633,7 @@ class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
628633 def deserialize (
629634 self ,
630635 serialized_artifact : export_serialize .SerializedArtifact ,
631- ) -> exir .ExportedProgram :
636+ ) -> ep .ExportedProgram :
632637 assert isinstance (serialized_artifact .exported_program , schema .ExportedProgram )
633638
634639 symbol_name_to_range = {
@@ -738,7 +743,7 @@ def serialize(
738743def deserialize (
739744 artifact : export_serialize .SerializedArtifact ,
740745 expected_opset_version : Optional [Dict [str , int ]] = None ,
741- ) -> exir .ExportedProgram :
746+ ) -> ep .ExportedProgram :
742747 assert isinstance (artifact .exported_program , bytes )
743748 exported_program_str = artifact .exported_program .decode ("utf-8" )
744749 exported_program_dict = json .loads (exported_program_str )
@@ -750,3 +755,96 @@ def deserialize(
750755 serialized_exported_program , artifact .state_dict , artifact .constants
751756 )
752757 )
758+
759+
760+ def save (
761+ ep_save : ep .ExportedProgram ,
762+ f : Union [str , os .PathLike , io .BytesIO ],
763+ * ,
764+ extra_files : Optional [Dict [str , Any ]] = None ,
765+ opset_version : Optional [Dict [str , int ]] = None ,
766+ ) -> None :
767+ if not isinstance (ep_save , ep .ExportedProgram ):
768+ raise TypeError (f"save() expects an ExportedProgram but got { type (ep )} " )
769+
770+ artifact : SerializedArtifact = serialize (ep_save , opset_version )
771+
772+ if isinstance (f , (str , os .PathLike )):
773+ f = os .fspath (f )
774+
775+ with zipfile .ZipFile (f , "w" ) as zipf :
776+ # Save every field in the SerializedArtifact to a file.
777+ assert isinstance (artifact .exported_program , bytes )
778+ zipf .writestr ("serialized_exported_program.json" , artifact .exported_program )
779+ zipf .writestr ("serialized_state_dict.pt" , artifact .state_dict )
780+ zipf .writestr ("serialized_constants.pt" , artifact .constants )
781+
782+ zipf .writestr ("version" , "." .join (map (str , SCHEMA_VERSION )))
783+
784+ # Add extra files if provided
785+ if extra_files :
786+ for extra_file_name , content in extra_files .items ():
787+ encoded_content = content .encode ("utf-8" )
788+ zipf .writestr (f"extra_files/{ extra_file_name } " , encoded_content )
789+
790+
791+ def load (
792+ f : Union [str , os .PathLike , io .BytesIO ],
793+ * ,
794+ extra_files : Optional [Dict [str , Any ]] = None ,
795+ expected_opset_version : Optional [Dict [str , int ]] = None ,
796+ ) -> ep .ExportedProgram :
797+ if isinstance (f , (str , os .PathLike )):
798+ f = os .fspath (f )
799+
800+ extra_files = extra_files or {}
801+
802+ with zipfile .ZipFile (f , "r" ) as zipf :
803+ # Check the version
804+ version = zipf .read ("version" ).decode ().split ("." )
805+
806+ assert len (version ) == len (SCHEMA_VERSION )
807+ if version [0 ] != str (SCHEMA_VERSION [0 ]):
808+ raise RuntimeError (
809+ f"Serialized version { version } does not match our current "
810+ f"schema version { SCHEMA_VERSION } ."
811+ )
812+
813+ # Load serialized_ep and serialized_state_dict from the zip file
814+
815+ serialized_exported_program : Optional [bytes ] = None
816+ serialized_state_dict : Optional [bytes ] = None
817+ serialized_constants : Optional [bytes ] = None
818+
819+ for file_info in zipf .infolist ():
820+ file_content = zipf .read (file_info .filename )
821+
822+ if file_info .filename == "serialized_exported_program.json" :
823+ serialized_exported_program = file_content
824+ elif file_info .filename == "serialized_state_dict.json" :
825+ print ("This version of file is deprecated" )
826+ serialized_state_dict = file_content
827+ elif file_info .filename == "serialized_constants.json" :
828+ print ("This version of file is deprecated" )
829+ serialized_constants = file_content
830+ elif file_info .filename == "serialized_state_dict.pt" :
831+ serialized_state_dict = file_content
832+ elif file_info .filename == "serialized_constants.pt" :
833+ serialized_constants = file_content
834+ elif file_info .filename .startswith ("extra_files" ):
835+ filename = file_info .filename .split ("/" , 1 )[1 ]
836+ extra_files [filename ] = file_content .decode ("utf-8" )
837+
838+ assert serialized_exported_program is not None
839+ assert serialized_state_dict is not None
840+ assert serialized_constants is not None
841+ artifact : SerializedArtifact = SerializedArtifact (
842+ serialized_exported_program ,
843+ serialized_state_dict ,
844+ serialized_constants ,
845+ )
846+
847+ # Deserialize ExportedProgram
848+ ep = deserialize (artifact , expected_opset_version )
849+
850+ return ep
0 commit comments