# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Tuple

import torch
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg: Any) -> Dict[str, Any]:  # pyre-ignore[2]
    """
    Extract the checkpoint state from file and validate it. This includes the
    base model weights. If resume_from_checkpoint is True, it also includes
    the adapter weights and recipe state.
    """
    checkpointer = config.instantiate(
        cfg.checkpointer,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
    )
    checkpoint_dict = checkpointer.load_checkpoint()
    return checkpoint_dict


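# Usage sketch (illustrative only): load a checkpoint from a torchtune YAML
# config. The config path below is hypothetical; any torchtune finetuning
# config with `checkpointer` and `resume_from_checkpoint` entries should work.
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("llama3_2_1B_lora.yaml")
#   checkpoint_dict = load_checkpoint(cfg=cfg)

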
def setup_model(
    cfg: DictConfig,
    base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
    """
    Instantiate the model on the configured device and dtype, mark the LoRA
    adapter parameters as trainable, and load the base model weights.
    """
    device = get_device(device=cfg.device)
    dtype = get_dtype(cfg.dtype, device=device)
    with set_default_dtype(dtype), device:
        model = config.instantiate(cfg.model)

    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)
    model.load_state_dict(base_model_state_dict, strict=False)
    return model


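# Usage sketch (illustrative only): build the model from the same config and
# the base weights returned by load_checkpoint(). strict=False is used above
# because the checkpoint holds only the base weights, not the LoRA adapter
# parameters. The "model" key is an assumption based on torchtune's checkpoint
# format; verify it against the checkpointer you use.
#
#   model = setup_model(cfg=cfg, base_model_state_dict=checkpoint_dict["model"])

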
def export_model_lora_training(
    model: TrainingModule, example_args: Tuple[Any, ...], output_file: str  # pyre-ignore[2]
) -> None:
    """
    Export a model with LoRA adapters to ExecuTorch for training only.
    """

    # 0. Mark the LoRA layers as trainable (requires_grad = True) so that only
    #    the backward pass for these layers is exported later in the export
    #    process.
    set_trainable_params(model, get_adapter_params(model))

    print("Exporting model with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.
    # The math SDPA backend is used during tracing so that attention is
    # captured as plain ATen ops rather than a fused kernel.
    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph: ExportedProgram = export(model, example_args, strict=False)
        print("Creating a joint forward-backwards graph for training")
        joint_graph = _export_forward_backward(exported_graph)

    # 2. to_edge: Make optimizations for Edge devices.
    print("Lowering to edge dialect")
    edge_program = to_edge(joint_graph)

    print(edge_program._edge_programs["forward"].graph_module)

    # 3. to_executorch: Convert the graph to an ExecuTorch program.
    print("Exporting to executorch")
    executorch_program = edge_program.to_executorch()
    print(executorch_program.exported_program().graph_signature)
    print(f"Saving to {output_file}")
    with open(output_file, "wb") as file:
        file.write(executorch_program.buffer)
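
# End-to-end sketch (illustrative only, not executed): wiring the helpers in
# this file together. The config path, the "model" checkpoint key, and the
# TrainingModule(model, loss) constructor are assumptions; check the torchtune
# config and training_lib for the exact names and signatures.
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.load("llama3_2_1B_lora.yaml")  # hypothetical config path
#   checkpoint_dict = load_checkpoint(cfg=cfg)
#   model = setup_model(cfg=cfg, base_model_state_dict=checkpoint_dict["model"])
#   loss_fn = config.instantiate(cfg.loss)  # assumed loss entry in the config
#   training_module = TrainingModule(model, loss_fn)
#   # Example inputs: (batch, seq_len) token IDs and matching labels.
#   tokens = torch.zeros((1, 128), dtype=torch.long)
#   labels = torch.zeros((1, 128), dtype=torch.long)
#   export_model_lora_training(training_module, (tokens, labels), "model.pte")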