Commit b6e54d0

Gasoonjia authored and facebook-github-bot committed
move code under executorch/examples (#3176)
Summary: Pull Request resolved: #3176 This diff moves the LLM manual code from outside GitHub (Dave's and Georgey's copies) into the executorch codebase so it can be pointed to directly. After this diff, //executorch/examples/llm_manual will become the only source of truth for our LLM manual code. Reviewed By: byjlw, dbort Differential Revision: D56365058 fbshipit-source-id: 97280fc0ca955caabb6056cddbb72102ed711f2c
1 parent 45fd796 commit b6e54d0

File tree

7 files changed: +489 -0 lines changed

examples/llm_manual/CMakeLists.txt (+33 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for the executorch build.
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with the XNNPACK backend

# Include the executorch subdirectory.
add_subdirectory(
    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
    ${CMAKE_BINARY_DIR}/executorch)

# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend) # Provides the XNNPACK CPU acceleration backend

examples/llm_manual/README.md (+3 lines)

# LLM Manual

This directory stores the files that the [LLM Manual](https://pytorch.org/executorch/main/llm/getting-started.html) needs. Please refer to the documentation website for more information.

examples/llm_manual/basic_sampler.h (+20 lines)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <cstdint> // for int64_t
#include <vector>

class BasicSampler {
 public:
  BasicSampler() {}
  int64_t sample(std::vector<float> logits) {
    // Find the token with the highest log probability.
    int64_t max_index =
        std::max_element(logits.begin(), logits.end()) - logits.begin();
    return max_index;
  }
};
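
The sampler is a plain greedy argmax, so generation is deterministic; a temperature or top-k scheme would replace this function. A tiny driver like the one below (not part of this commit; the include path is an assumption) shows the intended use:

// sampler_demo.cpp -- hypothetical usage sketch for BasicSampler.
#include <cstdint>
#include <iostream>
#include <vector>

#include "basic_sampler.h" // Assumed to sit next to this file.

int main() {
  // Token 2 carries the highest logit, so greedy sampling returns 2.
  std::vector<float> logits = {0.1f, 0.5f, 2.3f, -1.0f};
  BasicSampler sampler;
  int64_t token = sampler.sample(logits);
  std::cout << "Sampled token id: " << token << std::endl;
  return 0;
}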

examples/llm_manual/basic_tokenizer.h (+192 lines)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cctype> // for std::isspace, std::isdigit
#include <cstdint> // for int64_t
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

class BasicTokenizer {
 public:
  BasicTokenizer(const std::string& filePath) {
    std::ifstream file(filePath);

    if (!file) {
      std::cerr << "Unable to open file";
      exit(9); // return with error code
    }
    std::string str(
        (std::istreambuf_iterator<char>(file)),
        std::istreambuf_iterator<char>());

    size_t i = 0u;
    i = consume_whitespace(str, i);
    i = expect(str, i, '{');

    while (i < str.size() && str[i] != '}') {
      i = consume_field(str, i);
    }

    // Build decode map as inverse of encode.
    for (auto& i : encode_) {
      decode_[i.second] = i.first;
    }
  }

  std::vector<int64_t> encode(const std::string& prompt) {
    std::vector<std::string> words = parse_prompt(prompt);
    std::vector<int64_t> result;
    for (auto word : words) {
      result.push_back(encode_[word]);
    }
    return result;
  }

  std::string decode(const std::vector<int64_t>& indices) {
    std::string result;
    for (const auto& index : indices) {
      result += decode_[index];
    }
    return result;
  }

 private:
  std::unordered_map<std::string, int64_t> encode_;
  std::unordered_map<int64_t, std::string> decode_;

  // Advance the input string index until a non-whitespace character is found
  // or it reaches the end of string.
  size_t consume_whitespace(const std::string& data, size_t i) {
    while (i < data.size() && std::isspace(data[i])) {
      i++;
    }

    return i;
  }

  // Consumes a JSON field of the form
  //   "str": id,
  size_t consume_field(const std::string& data, size_t i) {
    i = consume_whitespace(data, i);

    // Parse the key literal.
    i = expect(data, i, '"');

    auto in_escape = false;
    std::string key = "";
    while (i < data.size()) {
      if (in_escape) {
        key += data[i];
        i++;
        in_escape = false;
      } else { // !in_escape
        if (data[i] == '"') { // End of string literal
          i++;
          break;
        } else if (data[i] == '\\') { // Escaped code point
          in_escape = true;
        }
        key += data[i];
        i++;
      }
    }

    key = post_process_key(key);

    i = expect(data, i, ':');
    i = consume_whitespace(data, i);

    // Read the unsigned integer value.
    auto value_start = i;
    while (i < data.size() && std::isdigit(data[i])) {
      i++;
    }
    auto value = static_cast<int64_t>(
        std::stol(data.substr(value_start, i - value_start)));

    encode_[key] = value;

    i = consume_whitespace(data, i);
    if (i < data.size() && data[i] == ',') {
      i++;
    }

    return i;
  }

  // Assert that the next character in the input string is equal to c.
  // Increment the input string index by one.
  size_t expect(const std::string& data, size_t i, char c) {
    if (i >= data.size() || data[i] != c) {
      std::cerr << "Invalid tokenizer vocabulary file. Expected '" << c
                << "' at index " << i << std::endl;
      exit(1);
    }

    return i + 1;
  }

  std::string post_process_key(std::string key) {
    // Replace the unicode escape sequences with the corresponding byte
    // encoding.
    // TODO: adopt byte encoder to handle unicode characters in json file.

    std::unordered_map<std::string, std::string> replacements = {
        {"\\u0120", " "},
        {"\\u010a", "\n"},
    };

    for (const auto& replacement : replacements) {
      size_t pos = 0;
      // Loop through all instances of the substring in the string.
      while ((pos = key.find(replacement.first, pos)) != std::string::npos) {
        key.replace(pos, replacement.first.length(), replacement.second);
        pos += replacement.second.length();
      }
    }

    // Remove duplicate backslashes.
    for (size_t idx = 0; idx < key.length(); idx++) {
      if (key[idx] == '\\') {
        key.erase(idx, 1);
        if (key[idx] == '\\') {
          // If there are two backslashes, keep the second one.
          idx += 1;
        }
      }
    }

    return key;
  }

  // Split the prompt into words and punctuation. A leading space stays
  // attached to the following word, matching GPT-2 style vocabulary entries.
  std::vector<std::string> parse_prompt(const std::string& prompt) {
    std::vector<std::string> result;
    std::string word;
    for (char c : prompt) {
      if (c == ' ') {
        if (!word.empty()) {
          result.push_back(word);
          word.clear();
        }
        word += c;
      } else if (ispunct(c)) {
        if (!word.empty()) {
          result.push_back(word);
          word.clear();
        }
        result.push_back(std::string(1, c));
      } else {
        word += c;
      }
    }
    if (!word.empty()) {
      result.push_back(word);
    }
    return result;
  }
};
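
A round trip through encode() and decode() is the quickest way to exercise the class. The sketch below is not part of this commit, and it assumes a GPT-2 style vocab.json next to the binary:

// tokenizer_demo.cpp -- hypothetical round-trip sketch for BasicTokenizer.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "basic_tokenizer.h" // Assumed to sit next to this file.

int main() {
  // Assumed: a GPT-2 style {"token": id, ...} vocabulary file.
  BasicTokenizer tokenizer("vocab.json");

  std::vector<int64_t> ids = tokenizer.encode("Hello, world");
  std::string text = tokenizer.decode(ids);
  std::cout << text << std::endl; // Prints the reconstructed prompt.
  return 0;
}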

examples/llm_manual/export_nanogpt.py (+45 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# export_nanogpt.py

# Load the partitioner for the XNNPACK backend.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# A model delegated to a specific backend should use that backend's edge
# compile config.
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge

from model import GPT
from torch._export import capture_pre_autograd_graph
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend

model = GPT.from_pretrained("gpt2")  # Use the GPT-2 weights as the pretrained weights.
example_inputs = (
    torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long),
)
dynamic_shape = ({1: torch.export.Dim("token_dim", max=model.config.block_size)},)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
# To be lowered to the XNNPACK backend, `traced_model` needs the
# XNNPACK-specific edge compile config.
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Delegate the exported model to the XNNPACK backend by invoking `to_backend`
# with the XNNPACK partitioner.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

# Save the XNNPACK-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)
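
The resulting nanogpt.pte is what main.cpp (also moved by this commit, not shown in this excerpt) consumes through the Module class that CMakeLists.txt links in. As a minimal sketch of just the loading step, assuming the torch::executor namespace used by ExecuTorch at the time of this commit (later releases reorganized these names):

// load_check.cpp -- hypothetical sanity check that the exported program loads.
#include <iostream>

#include <executorch/extension/module/module.h>

int main() {
  // Open the program exported by export_nanogpt.py.
  torch::executor::Module module("nanogpt.pte");

  // Loading is lazy; force it here so a bad file fails fast.
  if (module.load() != torch::executor::Error::Ok) {
    std::cerr << "Failed to load nanogpt.pte" << std::endl;
    return 1;
  }
  std::cout << "nanogpt.pte loaded; ready for forward()" << std::endl;
  return 0;
}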
