@@ -676,47 +676,62 @@ def _validate_args(args):
         )
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
 
-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"
 
     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"
 
+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -731,7 +746,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"vulkan_{modelname}"
 
     # Need to remove asserts from the graph to prevent graph breaks
-    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.mps:
@@ -760,13 +774,11 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -792,19 +804,15 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                     atten.head_dim,
                 )
             )
-            # pyre-ignore
             tag_quant_io(
                 builder_exported_to_edge.edge_manager.exported_program().graph_module,
-                partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+                partial(get_custom_quant_ios_dtype, cache_shape),
             )
 
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -818,7 +826,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(
@@ -840,11 +847,55 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(passes=additional_passes)
 
+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -866,7 +917,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     output_file = f"{builder.output_dir}/{modelname}.pte"
 
     builder.save_to_pte(output_file)
-
     return builder
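The "Order matters here" comment in the first hunk is doing real work: when both `xnnpack` and `xnnpack_extended_ops` are enabled, the dynamic-quant-only partitioner must come first so it claims the dynamically quantized linear ops before the general XNNPACK partitioner partitions the rest of the graph. A minimal sketch of the list `_to_edge_and_lower_llama_xnnpack` builds in that case (using only names from the diff):

```python
# Partitioner ordering from _to_edge_and_lower_llama_xnnpack() when
# args.xnnpack_extended_ops is set. Per the diff's comment, the
# dynamic-quant-only partitioner runs first.
partitioners = [
    get_xnnpack_partitioner(dynamic_quant_only_partitioner=True),   # dq linears first
    get_xnnpack_partitioner(dynamic_quant_only_partitioner=False),  # remaining ops second
]
```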
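And a hedged sketch of how the refactored entry point dispatches. A real invocation goes through the module's argument parser, which populates many more fields consumed by `_validate_args` and `_prepare_for_llama_export`; the `Namespace` below is a hypothetical minimum that only highlights the flags this diff branches on, with illustrative values rather than the parser's actual defaults.

```python
# Hypothetical driver for the refactored _export_llama(); field values are
# assumptions for illustration, not real defaults, and a real run needs
# many more fields than shown here.
from argparse import Namespace

args = Namespace(
    model="llama2",              # assumed model key; TORCHTUNE_DEFINED_MODELS gates extra passes
    xnnpack=True,                # True -> _to_edge_and_lower_llama_xnnpack()
    xnnpack_extended_ops=False,  # True would also add the general XNNPACK partitioner
    export_only=False,           # True exits right after export, before lowering
    generate_etrecord=False,     # must be False on the XNNPACK path (NotImplementedError)
    verbose=False,               # True prints delegation info after lowering
    profile_memory=False,        # True writes memory_profile.json
)

builder = _export_llama(args)   # returns an LLMEdgeManager; writes {output_dir}/{modelname}.pte
```

Note that pt2e dynamic quantization still forces the XNNPACK path: if `pt2e_quant_params.quantize_linear` is set, `_export_llama` flips `args.xnnpack` to `True` before dispatching, preserving the old behavior while keeping the flag mutation out of the lowering helpers.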