
Commit 3c6fbde

Olivia-liu authored and facebook-github-bot committed
Refine the LLM manual (focus on the debugging and profiling part) (#2952)
Summary:
Pull Request resolved: #2952

* Some auto-formatting by my VSCode (remove extra spaces)
* Remove imports that have been imported in previous part of the doc
* Other minor changes to keep consistency across the doc
* Link a screenshot instead of using the raw table because the original table is illegible: {F1482781056}

Differential Revision: D55938344
1 parent de7fdaa commit 3c6fbde

File tree

2 files changed (+50 −41 lines)

docs/source/llm/getting-started.md

+50 −41
@@ -14,7 +14,7 @@
 
 ## Prerequisites
 
-To follow this guide, you'll need to clone the ExecuTorch repository and install dependencies.
+To follow this guide, you'll need to clone the ExecuTorch repository and install dependencies.
 ExecuTorch recommends Python 3.10 and the use of Conda to manage your environment. Conda is not
 required, though be aware that you may need to replace the use of python/pip with python3/pip3
 depending on your environment.
@@ -82,7 +82,7 @@ For more information, see [Setting Up ExecuTorch](https://pytorch.org/executorch
 
 ## Running a Large Language Model Locally
 
-This example uses Karpathy’s [NanoGPT](https://github.com/karpathy/nanoGPT), which is a minimal implementation of
+This example uses Karpathy’s [NanoGPT](https://github.com/karpathy/nanoGPT), which is a minimal implementation of
 GPT-2 124M. This guide is applicable to other language models, as ExecuTorch is model-invariant.
 
 There are two steps to running a model with ExecuTorch:
@@ -129,7 +129,7 @@ Create a file called export_nanogpt.py with the following contents:
 
 import torch
 
-from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir import EdgeCompileConfig, to_edge
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch._export import capture_pre_autograd_graph
 from torch.export import export
@@ -139,7 +139,7 @@ from model import GPT
 # Load the model.
 model = GPT.from_pretrained('gpt2')
 
-# Create example inputs. This is used in the export process to provide
+# Create example inputs. This is used in the export process to provide
 # hints on the expected shape of the model input.
 example_inputs = (torch.randint(0, 100, (1, 8), dtype=torch.long), )
 
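For context while reading the hunks above and below, the export script in the guide continues roughly as follows. This is a sketch assembled from the surrounding sections of getting-started.md, not part of this diff; exact arguments may differ between ExecuTorch versions.

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph
from torch.export import export
from model import GPT  # NanoGPT's model.py

# Load NanoGPT and build example inputs that hint at the expected input shape.
model = GPT.from_pretrained('gpt2')
example_inputs = (torch.randint(0, 100, (1, 8), dtype=torch.long), )

# Trace with the math-only SDPA backend so attention decomposes into exportable ops.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs)
    traced_model = export(m, example_inputs)

# Lower to the Edge dialect, convert to an ExecuTorch program, and serialize it.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model, compile_config=edge_config)
et_program = edge_manager.to_executorch()

with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)
```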
@@ -211,8 +211,8 @@ std::string generate(
 BasicSampler& sampler,
 size_t max_output_length) {
 
-// Convert the input text into a list of integers (tokens) that represents
-// it, using the string-to-token mapping that the model was trained on.
+// Convert the input text into a list of integers (tokens) that represents
+// it, using the string-to-token mapping that the model was trained on.
 // Each token is an integer that represents a word or part of a word.
 std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
 std::vector<int64_t> output_tokens;
@@ -221,8 +221,8 @@ std::string generate(
 // Convert the input_tokens from a vector of int64_t to EValue.
 // EValue is a unified data type in the ExecuTorch runtime.
 ManagedTensor tensor_tokens(
-input_tokens.data(),
-{1, static_cast<int>(input_tokens.size())},
+input_tokens.data(),
+{1, static_cast<int>(input_tokens.size())},
 ScalarType::Long);
 std::vector<EValue> inputs = {tensor_tokens.get_tensor()};
 
@@ -232,7 +232,7 @@ std::string generate(
 // Convert the output logits from EValue to std::vector, which is what
 // the sampler expects.
 Tensor logits_tensor = logits_evalue.get()[0].toTensor();
-std::vector<float> logits(logits_tensor.data_ptr<float>(),
+std::vector<float> logits(logits_tensor.data_ptr<float>(),
 logits_tensor.data_ptr<float>() + logits_tensor.numel());
 
 // Sample the next token from the logits.
@@ -255,9 +255,9 @@ std::string generate(
 }
 ```
 
-The `Module` class handles loading the .pte file and preparing for execution.
+The `Module` class handles loading the .pte file and preparing for execution.
 
-The tokenizer is responsible for converting from a human-readable string representation of the prompt to the
+The tokenizer is responsible for converting from a human-readable string representation of the prompt to the
 numerical form expected by the model. To do this, the tokenizer associates short substrings with a given token ID.
 The tokens can be thought of as representing words or parts of words, though, in practice, they may be arbitrary
 sequences of characters.
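The tokenizer paragraph above can be made concrete with a quick illustration. The guide's runner uses its own BasicTokenizer in C++; the snippet below is only a sketch of the same idea using the `tiktoken` package, which ships GPT-2's BPE vocabulary (the package choice is an assumption for illustration, not part of the guide).

```python
# Illustration: GPT-2's BPE mapping between text and token IDs.
import tiktoken

enc = tiktoken.get_encoding("gpt2")

token_ids = enc.encode("Hello world!")
print(token_ids)              # a short list of integers, one per substring
print(enc.decode(token_ids))  # round-trips back to "Hello world!"
```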
@@ -312,7 +312,7 @@ and the [ExecuTorch Runtime API Reference](https://pytorch.org/executorch/main/e
 
 ExecuTorch uses the CMake build system. To compile and link against the ExecuTorch runtime,
 include the ExecuTorch project via `add_directory` and link against `executorch` and additional
-dependencies.
+dependencies.
 
 Create a file named CMakeLists.txt with the following content:
 
@@ -374,7 +374,7 @@ specific hardware (delegation), and because it is doing all of the calculations
 
 ## Delegation
 
-While ExecuTorch provides a portable, cross-platform implementation for all operators, it also provides specialized
+While ExecuTorch provides a portable, cross-platform implementation for all operators, it also provides specialized
 backends for a number of different targets. These include, but are not limited to, x86 and ARM CPU acceleration via
 the XNNPACK backend, Apple acceleration via the CoreML backend and Metal Performance Shader (MPS) backend, and GPU
 acceleration via the Vulkan backend.
@@ -395,11 +395,10 @@ To delegate to the XNNPACK backend, call `to_backend` with an instance of `Xnnpa
 
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
-from executorch.exir import EdgeCompileConfig, to_edge
 
 #...
 
-edge_config = edge_config = get_xnnpack_edge_compile_config()
+edge_config = get_xnnpack_edge_compile_config()
 edge_manager = to_edge(traced_model, compile_config=edge_config)
 
 # Delegate to the XNNPACK backend.
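For reference, the hunk above sits inside a snippet that continues with the actual `to_backend` call. A minimal sketch of the full XNNPACK delegation flow, assuming `traced_model` from the export step earlier in the guide:

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge

# Lower with the XNNPACK-specific edge compile config.
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Delegate every supported subgraph to the XNNPACK backend, then finish as before.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)
```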
@@ -433,15 +432,15 @@ and [CoreML Backend](https://pytorch.org/executorch/stable/build-run-coreml.html
 ## Quantization
 
 Quantization refers to a set of techniques for running calculations and storing tensors using lower precision types.
-Compared to 32-bit floating point, using 8-bit integers can provide both a significant speedup and reduction in
-memory usage. There are many approaches to quantizing a model, varying in amount of pre-processing required, data
+Compared to 32-bit floating point, using 8-bit integers can provide both a significant speedup and reduction in
+memory usage. There are many approaches to quantizing a model, varying in amount of pre-processing required, data
 types used, and impact on model accuracy and performance.
 
 Because compute and memory are highly constrained on mobile devices, some form of quantization is necessary to ship
 large models on consumer electronics. In particular, large language models, such as Llama2, may require quantizing
 model weights to 4 bits or less.
 
-Leveraging quantization requires transforming the model before export. PyTorch provides the pt2e (PyTorch 2 Export)
+Leveraging quantization requires transforming the model before export. PyTorch provides the pt2e (PyTorch 2 Export)
 API for this purpose. This example targets CPU acceleration using the XNNPACK delegate. As such, it needs to use the
 XNNPACK-specific quantizer. Targeting a different backend will require use of the corresponding quantizer.
 
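The pt2e flow described above is not shown in this hunk; a minimal sketch of what it looks like with the XNNPACK quantizer follows. Import paths and options are an assumption based on the PyTorch 2 export quantization APIs and may vary by version; `model` and `example_inputs` come from the export script earlier in the guide.

```python
from torch._export import capture_pre_autograd_graph
from torch.export import export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Configure the XNNPACK-specific quantizer for symmetric 8-bit quantization.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# Annotate the captured graph, run it once to calibrate, then convert it.
m = capture_pre_autograd_graph(model, example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration pass
m = convert_pt2e(m, fold_quantize=False)

# The quantized graph is then exported and lowered exactly as before.
traced_model = export(m, example_inputs)
```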
@@ -504,14 +503,14 @@ et_program = edge_manager.to_executorch()
 Finally, ensure that the runner links against the `xnnpack_backend` target in CMakeLists.txt.
 
 ```
-add_executable(nanogpt_runner nanogpt_runner.cpp)
+add_executable(nanogpt_runner main.cpp)
 target_link_libraries(
 nanogpt_runner
 PRIVATE
-etdump
-extension_module
-portable_ops_lib
-xnnpack_backend) # Link the XNNPACK backend
+executorch
+extension_module_static # Provides the Module class
+optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
+xnnpack_backend) # Provides the XNNPACK CPU acceleration backend
 ```
 
 For more information, see [Quantization in ExecuTorch](https://pytorch.org/executorch/stable/quantization-overview.html).
@@ -530,6 +529,7 @@ The `get_delegation_info()` method provides a summary of what happened to the mo
 from executorch.exir.backend.utils import get_delegation_info
 from tabulate import tabulate
 
+# ... After call to to_backend(), but before to_executorch()
 graph_module = edge_manager.exported_program().graph_module
 delegation_info = get_delegation_info(graph_module)
 print(delegation_info.get_summary())
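The `tabulate` import in the hunk above is used a few lines further down in the guide for a per-operator breakdown. Roughly, it looks like the sketch below; treat the `get_operator_delegation_dataframe()` helper name as an assumption rather than a confirmed API.

```python
# Per-operator counts of delegated vs. non-delegated occurrences, rendered as a table.
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
```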
@@ -564,7 +564,7 @@ from executorch.exir.backend.utils import print_delegated_graph
 graph_module = edge_manager.exported_program().graph_module
 print(print_delegated_graph(graph_module))
 ```
-This may generate a large amount of output for large models. Consider using "Control+F" or "Command+F" to locate the operator you’re interested in
+This may generate a large amount of output for large models. Consider using "Control+F" or "Command+F" to locate the operator you’re interested in
 (e.g. “aten_view_copy_default”). Observe which instances are not under lowered graphs.
 
 In the fragment of the output for NanoGPT below, observe that embedding and add operators are delegated to XNNPACK while the sub operator is not.
@@ -600,12 +600,12 @@ In your export script, after calling `to_edge()` and `to_executorch()`, call `ge
 import copy
 
 # Make the deep copy immediately after to to_edge()
-edge_program_manager_copy = copy.deepcopy(edge_program_manager)
+edge_manager_copy = copy.deepcopy(edge_manager)
 
 # ...
 # Generate ETRecord right after to_executorch()
 etrecord_path = "etrecord.bin"
-generate_etrecord(etrecord_path, edge_program_manager_copy, et_program_manager)
+generate_etrecord(etrecord_path, edge_manager_copy, et_program)
 ```
 
 Run the export script and the ETRecord will be generated as `etrecord.bin`.
@@ -624,13 +624,14 @@ Include the ETDump header in your code.
 Create an Instance of the ETDumpGen class and pass it to the Module constructor.
 ```cpp
 std::unique_ptr<torch::executor::ETDumpGen> etdump_gen_ = std::make_unique<torch::executor::ETDumpGen>();
-Module llm_model("nanogpt.pte", Module::MlockConfig::UseMlock, std::move(etdump_gen_));
+Module model("nanogpt.pte", torch::executor::Module::MlockConfig::UseMlockIgnoreErrors, std::move(etdump_gen_));
 ```
 
-After execution, save the ETDump to a file. You can capture multiple model runs in a single trace, if desired.
+After calling `generate()`, save the ETDump to a file. You can capture multiple
+model runs in a single trace, if desired.
 ```cpp
 torch::executor::ETDumpGen* etdump_gen =
-static_cast<torch::executor::ETDumpGen*>(llm_model.event_tracer());
+static_cast<torch::executor::ETDumpGen*>(model.event_tracer());
 
 ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
 etdump_result result = etdump_gen->get_etdump_data();
@@ -643,9 +644,22 @@ if (result.buf != nullptr && result.size > 0) {
 }
 ```
 
-Compile the ExecuTorch runtime with the `ET_EVENT_TRACER_ENABLED` pre-processor flag to enable events to be traced and logged into ETDump inside the ExecuTorch runtime. Add these to your CMakeLists.txt
+Additionally, update CMakeLists.txt to build with SDK and enable events to be traced and logged into ETDump:
 
 ```
+option(EXECUTORCH_BUILD_SDK "" ON)
+
+# ...
+
+target_link_libraries(
+nanogpt_runner
+PRIVATE
+executorch
+extension_module_static # Provides the Module class
+optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
+xnnpack_backend # Provides the XNNPACK CPU acceleration backend
+etdump) # Provides event tracing and logging
+
 target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
 target_compile_options(portable_ops_lib PUBLIC -DET_EVENT_TRACER_ENABLED)
 ```
@@ -658,20 +672,15 @@ Once you’ve collected debug artifacts ETDump (and optionally an ETRecord), you
 ```python
 from executorch.sdk import Inspector
 
-inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")
-# If you did not generate an ETRecord, then just pass in the ETDump: `inspector = Inspector(etdump_path="etdump.etdp")`
+inspector = Inspector(etdump_path="etdump.etdp")
+# If you also generated an ETRecord, then pass that in as well: `inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")`
 
-inspector.print_data_tabular()
+with open("inspector_out.txt", "w") as file:
+    inspector.print_data_tabular(file)
 ```
 This prints the performance data in a tabular format in “inspector_out.txt”, with each row being a profiling event.
-
-| | event_block_name | event_name | p10 (ms) | p50 (ms) | p90 (ms) | avg (ms) | min (ms) | max (ms) | op_types | is_delegated_op | delegate_backend_name |
-|---|----------------------|------------------|-----------|---------------|--------------|-------------|-------------|--------------|-------------|---------------------------|----------|
-| 0 | Default | Method::init | 60.502 | 60.502 | 60.502 | 60.502 | 60.502 | 60.502 | [] | False | |
-| 1 | Default | Program::load_method | 60.5114 | 60.5114 | 60.5114 | 60.5114 | 60.5114 | 60.5114 | [] | False | |
-| 2 | Execute | native_call_arange.start_out | 0.029583 | 0.029583 | 0.029583 | 0.029583 | 0.029583 | 0.029583 | [] | False | |
-| 3 | Execute | native_call_embedding.out | 0.022916 | 0.022916 | 0.022916 | 0.022916 | 0.022916 | 0.022916 | [] | False | |
-| 4 | Execute | native_call_embedding.out | 0.001084 | 0.001084 | 0.001084 | 0.001084 | 0.001084 | 0.001084 | [] | False | |
+![](../_static/img/llm_manual_print_data_tabular.png)
+<a href="../_static/img/llm_manual_print_data_tabular.png" target="_blank">View in full size</a>
 
 To learn more about the Inspector and the rich functionality it provides, see the [Inspector API Reference](https://pytorch.org/executorch/main/sdk-inspector.html).
