
Commit 99968b9

dpalmasan authored and facebook-github-bot committed
Tuning LLM from PTE (#5233)
Summary:
Pull Request resolved: #5233
* Add example of finetuning using executorch
Reviewed By: JacobSzwejbka, dvorjackz
Differential Revision: D61689035
1 parent d80f78f commit 99968b9

File tree: 10 files changed, +712 −15 lines

examples/llm_pte_finetuning/TARGETS (new file, +69 lines)

load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("papaya_oncall")

python_library(
    name = "model_loading_lib",
    srcs = [
        "model_loading_lib.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/exir:lib",
        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
    ],
)

python_library(
    name = "training_lib",
    srcs = [
        "training_lib.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        "fbsource//third-party/pypi/tqdm:tqdm",
    ],
)

python_binary(
    name = "runner",
    srcs = [
        "runner.py",
    ],
    main_function = "executorch.examples.llm_pte_finetuning.runner.main",
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        "fbsource//third-party/pypi/tqdm:tqdm",
    ],
)

python_binary(
    name = "model_exporter",
    srcs = [
        "model_exporter.py",
    ],
    main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main",
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/examples/llm_pte_finetuning:model_loading_lib",  # @manual for model loading
        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",  # @manual for model exporting
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
    ],
)
examples/llm_pte_finetuning/model_exporter.py (new file, +87 lines; path per the TARGETS srcs entry)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
    export_model_lora_training,
    load_checkpoint,
    setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
    get_dataloader,
    TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
    prog="ModelExporter",
    description="Export a LoRA model to ExecuTorch.",
    epilog="Model exported to be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    output_file = args.output_file
    cfg = OmegaConf.load(config_file)
    tokenizer = config.instantiate(
        cfg.tokenizer,
    )

    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len

    # Example inputs, needed for ET export.
    batch = next(iter(train_dataloader))
    tokens, labels = batch["tokens"], batch["labels"]
    token_size = tokens.shape[1]
    labels_size = labels.shape[1]

    if token_size > max_seq_len:
        tokens = tokens[:, :max_seq_len]
    else:
        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

    if labels_size > max_seq_len:
        labels = labels[:, :max_seq_len]
    else:
        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

    # Load pre-trained checkpoint.
    checkpoint_dict = load_checkpoint(cfg=cfg)
    model = setup_model(
        # pyre-ignore
        cfg=cfg,
        base_model_state_dict=checkpoint_dict[MODEL_KEY],
    )

    training_module = TrainingModule(model, loss_fn)

    # Export the model to ExecuTorch for training.
    export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
    main()
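
For orientation, here is a minimal sketch of driving the exporter end to end from Python; it is not part of this commit. The --cfg and --output_file flags come from the argparse definition above, the module path matches the main_function declared for the model_exporter binary in TARGETS, and both file paths below are placeholders.

# A minimal sketch (not part of this commit): run the exporter programmatically.
# The two flags mirror the argparse definition in model_exporter.py; the config
# and output paths are placeholders.
import sys

from executorch.examples.llm_pte_finetuning import model_exporter

sys.argv = [
    "model_exporter",
    "--cfg", "path/to/finetune_config.yaml",  # any of the YAML configs added below
    "--output_file", "/tmp/model_lora.pte",
]
model_exporter.main()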
examples/llm_pte_finetuning/model_loading_lib.py (new file, +86 lines; path per the TARGETS srcs entry)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Tuple

import torch
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg: Any) -> Dict[str, Any]:  # pyre-ignore[2]
    """
    Extract the checkpoint state from file and validate. This includes the
    base model weights. If resume_from_checkpoint is True, this also includes
    the adapter weights and recipe state
    """
    checkpointer = config.instantiate(
        cfg.checkpointer,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
    )
    checkpoint_dict = checkpointer.load_checkpoint()
    return checkpoint_dict


def setup_model(
    cfg: DictConfig,
    base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
    device = get_device(device=cfg.device)
    dtype = get_dtype(cfg.dtype, device=device)
    with set_default_dtype(dtype), device:
        model = config.instantiate(cfg.model)

    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)
    model.load_state_dict(base_model_state_dict, strict=False)
    return model


def export_model_lora_training(
    model: TrainingModule, example_args: Tuple[Any, ...], output_file: str  # pyre-ignore[2]
) -> None:
    """
    Export model with LoRA model to executorch for training, only.
    """

    # 0. Mark the LoRA layers as trainable (requires_grad = True) in order
    # to just export the backwards pass for these layers later in the
    # export process.
    set_trainable_params(model, get_adapter_params(model))

    print("Exporting model with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.

    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph: ExportedProgram = export(model, example_args, strict=False)
        print("Creating a joint forward-backwards graph for training")
        joint_graph = _export_forward_backward(exported_graph)

        # 2. to_edge: Make optimizations for Edge devices.
        print("Lowering to edge dialect")
        edge_program = to_edge(joint_graph)

        print(edge_program._edge_programs["forward"].graph_module)

        # 3. to_executorch: Convert the graph to an ExecuTorch program.
        print("Exporting to executorch")
        executorch_program = edge_program.to_executorch()
        print(executorch_program.exported_program().graph_signature)
        print(f"Saving to {output_file}")
        with open(output_file, "wb") as file:
            file.write(executorch_program.buffer)
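
The .pte written by export_model_lora_training contains the joint forward/backward graph. The sketch below, which is not part of this commit, shows one way it might be loaded and run from Python; it assumes a pybindings build that exposes _load_for_executorch (the aten_lib dependency marked "For PTE loader" in TARGETS suggests that loader) and that the program keeps the default "forward" method name, so treat the import path, method name, and output layout as assumptions.

# A minimal, hypothetical sketch: load the exported .pte and run the joint
# forward/backward graph once. The import path, method name, and output layout
# are assumptions, not taken from this commit.
import torch

from executorch.extension.pybindings.aten_lib import _load_for_executorch

et_mod = _load_for_executorch("/tmp/model_lora.pte")  # placeholder path

# Inputs must match the example shapes used at export time: (1, max_seq_len).
tokens = torch.zeros((1, 512), dtype=torch.long)
labels = torch.zeros((1, 512), dtype=torch.long)

outputs = et_mod.forward((tokens, labels))
# The outputs hold the loss and the gradients of the trainable LoRA parameters;
# the graph_signature printed during export shows their exact ordering.
print(f"got {len(outputs)} outputs from the joint graph")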
New YAML config (+49 lines): Phi-3 Mini LoRA fine-tuning on the iamtarun/python_code_instructions_18k_alpaca dataset.

tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 1024

dataset:
  _component_: torchtune.datasets.instruct_dataset
  template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly
  source: iamtarun/python_code_instructions_18k_alpaca
  split: train
  column_map:
    instruction: instruction
    prompt: prompt
    input: input
    output: output
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI

resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
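
Each _component_ key in these YAML files names a callable that torchtune's config.instantiate resolves, with the sibling keys passed as keyword arguments; this is exactly how model_exporter.py consumes them. A brief sketch follows (the config path is a placeholder, not a file name from this commit):

# Sketch of how the YAML above is consumed; mirrors the calls in model_exporter.main.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("path/to/phi3_code_instructions_config.yaml")  # placeholder path

tokenizer = config.instantiate(cfg.tokenizer)         # phi3_mini_tokenizer(path=..., max_seq_len=1024)
loss_fn = config.instantiate(cfg.loss)                # torch.nn.CrossEntropyLoss()
dataset = config.instantiate(cfg.dataset, tokenizer)  # instruct_dataset(tokenizer, template=..., source=..., ...)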
New YAML config (+40 lines): Phi-3 Mini LoRA fine-tuning on the alpaca_cleaned_dataset.

tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
New YAML config (+39 lines): Qwen2 0.5B Instruct LoRA fine-tuning on the alpaca_cleaned_dataset.

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct
  model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
