Tuning LLM from PTE #5233

70 changes: 70 additions & 0 deletions examples/llm_pte_finetuning/TARGETS
@@ -0,0 +1,70 @@
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("papaya_oncall")

python_library(
    name = "model_loading_lib",
    srcs = [
        "model_loading_lib.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",
        "fbcode//executorch/exir:lib",
        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
    ],
)

python_library(
    name = "training_lib",
    srcs = [
        "training_lib.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        "fbsource//third-party/pypi/tqdm:tqdm",
    ],
)

python_binary(
    name = "runner",
    srcs = [
        "runner.py",
    ],
    main_function = "executorch.examples.llm_pte_finetuning.runner.main",
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        "fbsource//third-party/pypi/tqdm:tqdm",
    ],
)

python_binary(
    name = "model_exporter",
    srcs = [
        "model_exporter.py",
    ],
    main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main",
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/examples/llm_pte_finetuning:model_loading_lib",  # @manual for model loading
        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",  # @manual for model exporting
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
    ],
)
87 changes: 87 additions & 0 deletions examples/llm_pte_finetuning/model_exporter.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
    export_model_lora_training,
    load_checkpoint,
    setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
    get_dataloader,
    TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
    prog="ModelExporter",
    description="Export a LoRA model to ExecuTorch.",
    epilog="Model exported to be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    output_file = args.output_file
    cfg = OmegaConf.load(config_file)
    tokenizer = config.instantiate(
        cfg.tokenizer,
    )

    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len

    # Example inputs, needed for ET export.
    batch = next(iter(train_dataloader))
    tokens, labels = batch["tokens"], batch["labels"]
    token_size = tokens.shape[1]
    labels_size = labels.shape[1]

    if token_size > max_seq_len:
        tokens = tokens[:, :max_seq_len]
    else:
        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

    if labels_size > max_seq_len:
        labels = labels[:, :max_seq_len]
    else:
        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

    # Load pre-trained checkpoint.
    checkpoint_dict = load_checkpoint(cfg=cfg)
    model = setup_model(
        # pyre-ignore
        cfg=cfg,
        base_model_state_dict=checkpoint_dict[MODEL_KEY],
    )

    training_module = TrainingModule(model, loss_fn)

    # Export the model to ExecuTorch for training.
    export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
    main()
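
The truncate-or-pad block in `main()` above exists because `torch.export` needs fixed-shape example inputs: every batch is forced to exactly `cfg.tokenizer.max_seq_len` tokens. A minimal, self-contained sketch of that same logic on a toy tensor (the shapes and values below are made up for illustration):

```python
import torch
from torch.nn import functional as F

max_seq_len = 8
tokens = torch.randint(0, 100, (1, 5))  # batch of 1, only 5 tokens long (toy data)

# Same branch structure as main(): truncate when too long, otherwise right-pad with 0.
if tokens.shape[1] > max_seq_len:
    tokens = tokens[:, :max_seq_len]
else:
    tokens = F.pad(tokens, (0, max_seq_len - tokens.shape[1]), value=0)

print(tokens.shape)  # torch.Size([1, 8])
```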
88 changes: 88 additions & 0 deletions examples/llm_pte_finetuning/model_loading_lib.py
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Tuple

import torch
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg: Any) -> Dict[str, Any]:  # pyre-ignore[2]
    """
    Extract the checkpoint state from file and validate. This includes the
    base model weights. If resume_from_checkpoint is True, this also includes
    the adapter weights and recipe state.
    """
    checkpointer = config.instantiate(
        cfg.checkpointer,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
    )
    checkpoint_dict = checkpointer.load_checkpoint()
    return checkpoint_dict


def setup_model(
    cfg: DictConfig,
    base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
    device = get_device(device=cfg.device)
    dtype = get_dtype(cfg.dtype, device=device)
    with set_default_dtype(dtype), device:
        model = config.instantiate(cfg.model)

    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)
    model.load_state_dict(base_model_state_dict, strict=False)
    return model


def export_model_lora_training(
    model: TrainingModule,
    example_args: Tuple[Any, ...],  # pyre-ignore[2]
    output_file: str,
) -> None:
    """
    Export a model with LoRA adapters to ExecuTorch, for training only.
    """

    # 0. Mark the LoRA layers as trainable (requires_grad = True) so that
    # only the backward pass for these layers is exported later in the
    # export process.
    set_trainable_params(model, get_adapter_params(model))

    print("Exporting model with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.

    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph: ExportedProgram = export(model, example_args, strict=False)
        print("Creating a joint forward-backwards graph for training")
        joint_graph = _export_forward_backward(exported_graph)

    # 2. to_edge: Make optimizations for Edge devices.
    print("Lowering to edge dialect")
    edge_program = to_edge(joint_graph)

    print(edge_program._edge_programs["forward"].graph_module)

    # 3. to_executorch: Convert the graph to an ExecuTorch program.
    print("Exporting to executorch")
    executorch_program = edge_program.to_executorch()
    print(executorch_program.exported_program().graph_signature)
    print(f"Saving to {output_file}")
    with open(output_file, "wb") as file:
        file.write(executorch_program.buffer)
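
The three numbered steps in `export_model_lora_training` are the standard ExecuTorch training-export flow: `torch.export` captures the forward graph, `_export_forward_backward` derives the joint forward-backward graph, and `to_edge(...).to_executorch()` lowers and serializes it. Below is a minimal sketch of that same flow on a toy module, assuming `executorch` and a recent PyTorch are installed; `ToyTrainingModule` and its small linear layer are invented stand-ins for a `TrainingModule` wrapping the LoRA model:

```python
import torch
from torch.export import export
from torch.export.experimental import _export_forward_backward

from executorch.exir import to_edge


class ToyTrainingModule(torch.nn.Module):
    """Stand-in for TrainingModule: forward returns the scalar loss."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self.loss(self.linear(tokens), labels)


example_args = (torch.randn(2, 8), torch.randint(0, 4, (2,)))

# 1. Capture the forward graph, then derive the joint forward-backward graph.
exported = export(ToyTrainingModule(), example_args, strict=False)
joint = _export_forward_backward(exported)

# 2./3. Lower to the edge dialect and serialize an ExecuTorch (.pte) program.
et_program = to_edge(joint).to_executorch()
with open("toy_training.pte", "wb") as f:
    f.write(et_program.buffer)
```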
49 changes: 49 additions & 0 deletions examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml
@@ -0,0 +1,49 @@
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 1024

dataset:
  _component_: torchtune.datasets.instruct_dataset
  template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly
  source: iamtarun/python_code_instructions_18k_alpaca
  split: train
  column_map:
    instruction: instruction
    prompt: prompt
    input: input
    output: output
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI

resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
40 changes: 40 additions & 0 deletions examples/llm_pte_finetuning/phi3_config.yaml
@@ -0,0 +1,40 @@
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
39 changes: 39 additions & 0 deletions examples/llm_pte_finetuning/qwen_05b_config.yaml
@@ -0,0 +1,39 @@
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct
  model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
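
All three YAML files follow the same torchtune config layout and are consumed the same way: `OmegaConf.load` reads the file and `torchtune.config.instantiate` turns each `_component_` entry into a live object, exactly as `model_exporter.main()` does. A minimal sketch of that path (the relative config path below is an assumption):

```python
from omegaconf import OmegaConf
from torchtune import config

# Point this at any of the configs in this directory; the path is illustrative.
cfg = OmegaConf.load("examples/llm_pte_finetuning/qwen_05b_config.yaml")

tokenizer = config.instantiate(cfg.tokenizer)         # qwen2_tokenizer(path=..., merges_file=..., max_seq_len=512)
loss_fn = config.instantiate(cfg.loss)                # torch.nn.CrossEntropyLoss()
dataset = config.instantiate(cfg.dataset, tokenizer)  # alpaca_cleaned_dataset(tokenizer)

print(cfg.tokenizer.max_seq_len, cfg.batch_size, cfg.dtype)
```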