Commit f95d42f

Support llama3 training
Differential Revision: D70977130
Pull Request resolved: #9149
1 parent be92fb4 commit f95d42f

File tree: 3 files changed, +111 −4 lines changed

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
output_dir: /tmp/llama-3.2-1B_ft-output # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 64  # higher increases accuracy and memory
  lora_alpha: 128  # usually alpha=2*rank
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  special_tokens_path: null
  max_seq_len: 512
  prompt_template: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False  # True increases speed
seed: null
shuffle: True

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size
optimizer_in_bwd: True  # True saves memory. Requires gradient_accumulation_steps=1
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cpu
dtype: fp32

# Memory management
enable_activation_checkpointing: False  # True reduces memory
enable_activation_offloading: False  # True reduces memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
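
For reference, a torchtune-style config like this one is plain YAML whose `_component_` entries name the builders to instantiate; model_exporter.py below consumes it through `torchtune.config`. A minimal sketch of that pattern, assuming the file is saved as `llama3_2_1B_lora.yaml` (a placeholder name, not part of this commit):

```python
# Sketch only: load a torchtune-style YAML config and instantiate the
# components named by its `_component_` keys. The config file name below
# is a placeholder for wherever this new config lands in the repo.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("llama3_2_1B_lora.yaml")

model = config.instantiate(cfg.model)            # lora_llama3_2_1b with the LoRA settings above
tokenizer = config.instantiate(cfg.tokenizer)    # llama3_tokenizer with max_seq_len=512
loss_fn = config.instantiate(cfg.loss)           # CEWithChunkedOutputLoss
dataset = config.instantiate(cfg.dataset, tokenizer)  # alpaca_dataset

print(type(model).__name__, cfg.batch_size, cfg.epochs)
```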

examples/llm_pte_finetuning/model_exporter.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def main() -> None:
     loss_fn = config.instantiate(cfg.loss)

     ds = config.instantiate(cfg.dataset, tokenizer)
-    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
+    train_set, _ = torch.utils.data.random_split(ds, [0.8, 0.2])
     train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

     max_seq_len = cfg.tokenizer.max_seq_len
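
The exporter only needs the training split, so the unused 20% validation split is now discarded. For context, `torch.utils.data.random_split` accepts fractional lengths (PyTorch 1.13+) and divides the dataset proportionally; a toy sketch using a dummy `TensorDataset` rather than the Alpaca dataset instantiated above:

```python
# Illustrative only: fractional lengths in random_split split the dataset
# proportionally; the second split is dropped, mirroring the diff above.
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))
train_set, _ = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(0))

print(len(train_set))  # 80 of the 100 examples land in the training split
```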

examples/llm_pte_finetuning/training_lib.py

Lines changed: 16 additions & 3 deletions
@@ -36,13 +36,26 @@ def __init__(
         super().__init__()
         self.model = model
         self.loss = loss
+        if loss.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            # pyre-ignore
+            model.set_num_output_chunks(self.loss.num_output_chunks)
+
+        # (batch_size, 1) tensor of ignore_index
+        # pyre-ignore
+        self.ignore_labels_cache = torch.full(
+            (1, 1), self.loss.ignore_index, device="cpu"  # pyre-ignore
+        )

     def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Output is of the shape (seq_len, vocab_size).
         logits = self.model(input)
-        logits = logits[..., :-1, :].contiguous()
-        labels = labels[..., 1:].contiguous()
-        logits = logits.transpose(1, 2)
+        labels = torch.hstack(
+            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+        )
+        if not isinstance(logits, list):
+            labels = labels.reshape(-1)
+            logits = logits.reshape(-1, logits.size(-1))
         return self.loss(logits, labels)
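
The rewritten `forward` no longer slices and transposes the logits. Instead, the labels are shifted left by one position and the final slot is padded with `ignore_index`, so the logit at position t is scored against the token at t+1 and the padded slot contributes nothing to the loss. When the loss is `CEWithChunkedOutputLoss`, the model returns a list of logit chunks that the loss consumes directly; only the plain-tensor case is flattened for standard cross-entropy. A standalone sketch of that label handling on dummy tensors (illustrative only, not the module in this commit):

```python
# Sketch of the label shifting above, on dummy data. Shapes, values, and
# IGNORE_INDEX are illustrative; the module uses self.loss.ignore_index.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
batch, seq_len, vocab = 2, 5, 11

logits = torch.randn(batch, seq_len, vocab)          # model output, plain tensor (non-chunked case)
labels = torch.randint(0, vocab, (batch, seq_len))   # token ids

ignore_labels_cache = torch.full((batch, 1), IGNORE_INDEX)

# Shift labels left by one and pad the last position with ignore_index,
# so logits[t] is scored against the token at t+1 without slicing logits.
labels = torch.hstack((labels[..., 1:], ignore_labels_cache[: labels.shape[0]]))

# Non-chunked path: flatten to (batch*seq_len, vocab) and (batch*seq_len,)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    labels.reshape(-1),
    ignore_index=IGNORE_INDEX,
)
print(loss.item())
```

Padding the labels with `ignore_index` instead of trimming the logits also avoids the extra `.contiguous()` copies that the removed lines required.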
