101101 "phi_4_mini" ,
102102 "smollm2" ,
103103]
104- TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision" ]
104+ TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision" , "llama3_2_lora" ]
105105HUGGING_FACE_REPO_IDS = {
106106 "qwen2_5" : "Qwen/Qwen2.5-1.5B" ,
107107 "phi_4_mini" : "microsoft/Phi-4-mini-instruct" ,
@@ -209,6 +209,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Path to the adapter checkpoint.",
+    )
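+    # Hypothetical usage sketch (paths are placeholders; assumes the
+    # llama3_2_lora entry registered in TORCHTUNE_DEFINED_MODELS above):
+    #   python -m examples.models.llama.export_llama \
+    #     --model llama3_2_lora --checkpoint <base_ckpt> --adapter <adapter_ckpt>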
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
@@ -585,17 +591,20 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
     # Convert dtype override string arg to actual type.
     dtype_override = DType[args.dtype_override]
 
+    # breakpoint()  # 1, OK.
     edge_manager = _load_llama_model(
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -616,10 +625,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         args=args,
     )
-
     # At this point, the model is loaded in the default fp32.
 
     # Checkpoint dtype should be lower or equal precision to the dtype override.
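+    # Debug harness (developer-local): run a small fixed sequence through the
+    # model and compare against a saved eager-mode reference output.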
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+
+    em1 = edge_manager.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(eager, em1)
+    # breakpoint()  # 4, OK.
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
     if not (
         checkpoint_dtype == dtype_override.to_torch_dtype()
@@ -637,6 +652,10 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
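+    # Debug check: re-run the same input after the dtype override; per the
+    # note below, this diverges when the model is converted to bf16.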
+    # edge_manager.model = edge_manager.model.to(dtype=torch.float32)
+    em2 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em2, eager)
+    # breakpoint()  # 5, not OK, gets converted to bf16. OK if dtype is consistent.
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
@@ -649,7 +668,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
     )
-
+    # Debug check: outputs should still match the eager reference after the
+    # source transforms have been applied.
+    em3 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em3, eager)
     return edge_manager
 
 
@@ -777,6 +798,9 @@ def _to_edge_and_lower_llama( # noqa: C901
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
+    breakpoint()
+    # ^ saved as to_edge_res.pt;
+    # allclose only within 1e-1 compared to the pre-autograd output.
 
     # to_backend
     partitioners = []
@@ -911,7 +935,16 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
 
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
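+    # Debug check: the exported module should still match the saved
+    # eager-mode reference on the same fixed input.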
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+    b_e = builder_exported.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(b_e, eager)
+    # breakpoint()
+
     builder_exported.run_canonical_optimizations()
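+    # Debug check: canonical optimizations should not change model outputs.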
+    b_e2 = builder_exported.model.forward(eg, input_pos=ip)
+    torch.allclose(b_e2, eager)
     modelname = builder_exported.modelname
 
     if args.export_only:
@@ -932,6 +965,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             args,
         )
     else:
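+        # Debug check: outputs should still match right before edge lowering.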
+        # breakpoint()
+        b_e3 = builder_exported.model.forward(eg, input_pos=ip)
+        torch.allclose(b_e3, eager)
         builder = _to_edge_and_lower_llama(
             builder_exported,
             modelname,
@@ -941,6 +977,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             quant_dtype,
             args,
         )
+        breakpoint()
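+        # (Inspect the lowered builder here before memory profiling / serialization.)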
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -1004,6 +1041,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1038,6 +1076,9 @@ def _load_llama_model(
         if modelname == "llama3_2_vision":
             module_name = "llama3_2_vision"
             model_class_name = "Llama3_2Decoder"
+        elif modelname == "llama3_2_lora":
+            module_name = "llama3_2_lora"
+            model_class_name = "Llama3_2_Lora"
         else:
             raise ValueError(f"{modelname} is not a valid Llama model.")
     else:
@@ -1051,6 +1092,7 @@ def _load_llama_model(
                 model_class_name,
                 checkpoint=checkpoint,
                 checkpoint_dir=checkpoint_dir,
+                adapter=adapter,
                 params=params_path,
                 use_kv_cache=use_kv_cache,
                 use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1066,6 +1108,7 @@ def _load_llama_model(
             )
         )
 
+    # breakpoint()  # 3. OK.
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -1093,7 +1136,7 @@ def _load_llama_model(
             model.max_seq_len,
             # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
             # `Union[Tensor, Module]`.
-            model.max_context_len,
+            max_context_len,
             # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.n_layers,
@@ -1244,6 +1287,9 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    # transforms.append(
+    #     replace_rope_with_inference_rope()
+    # )
     return transforms
 
 