diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 8e02c15b21c..51a2060c328 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -620,12 +620,28 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s save_bundled_program(exec_prog, method_test_suites, output_name) +def quantize_model( + exported_program, args, model: torch.nn.Module, example_inputs, compile_spec +): + model_int8 = quantize( + model, + args.model_name, + compile_spec, + example_inputs, + args.evaluate, + args.evaluate_config, + ) + # Wrap quantized model back into an exported_program + exported_program = torch.export.export_for_training( + model_int8, example_inputs, strict=True + ) + + return model_int8, exported_program + + def to_edge_TOSA_delegate( - exported_program, - args, - model: torch.nn.Module, + exported_program, args, model: torch.nn.Module, example_inputs ): - model_int8 = None # As we can target multiple output encodings, one must # be specified. compile_spec = get_compile_spec( @@ -634,23 +650,13 @@ def to_edge_TOSA_delegate( args.system_config, args.memory_mode, ) + + model_int8 = None if args.quantize: - model = quantize( - model, - args.model_name, - compile_spec, - example_inputs, - args.evaluate, - args.evaluate_config, + model_int8, exported_program = quantize_model( + exported_program, args, model, example_inputs, compile_spec ) - model_int8 = model - # Wrap quantized model back into an exported_program - exported_program = torch.export.export_for_training( - model, example_inputs, strict=True - ) - - if args.intermediates: - os.makedirs(args.intermediates, exist_ok=True) + model = model_int8 if is_ethosu(compile_spec): partitioner = EthosUPartitioner(compile_spec) @@ -669,6 +675,31 @@ def to_edge_TOSA_delegate( return model_int8, edge +def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs): + model_int8 = None + if args.quantize: + # As we can target multiple output encodings, one must + # be specified. + compile_spec = get_compile_spec( + args.target, + args.intermediates, + args.system_config, + args.memory_mode, + ) + model, exported_program = quantize_model( + exported_program, args, model, example_inputs, compile_spec + ) + model_int8 = model + + edge = to_edge_transform_and_lower( + exported_program, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), + ) + return model_int8, edge + + if __name__ == "__main__": # noqa: C901 args = get_args() @@ -686,16 +717,18 @@ def to_edge_TOSA_delegate( model = exported_program.module() model_fp32 = model + if args.intermediates: + os.makedirs(args.intermediates, exist_ok=True) + # Quantize if required model_int8 = None if args.delegate: - model_int8, edge = to_edge_TOSA_delegate(exported_program, args, model) + model_int8, edge = to_edge_TOSA_delegate( + exported_program, args, model, example_inputs + ) else: - edge = to_edge_transform_and_lower( - exported_program, - compile_config=EdgeCompileConfig( - _check_ir_validity=False, - ), + model_int8, edge = to_edge_no_delegate( + exported_program, args, model, example_inputs ) dump_delegation_info(edge, args.intermediates)