Skip to content

Commit 28aae25

Browse files
author
Guang Yang
committed
Export to ExecuTorch: Initial Integration
1 parent 7e8d857 commit 28aae25

17 files changed

Lines changed: 723 additions & 4 deletions

File tree

optimum/commands/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
1616
from .env import EnvironmentCommand
17-
from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
17+
from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand
1818
from .optimum_cli import optimum_cli_subcommand

optimum/commands/export/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414

1515

1616
from .base import ExportCommand
17+
from .executorch import ExecuTorchExportCommand
1718
from .onnx import ONNXExportCommand
1819
from .tflite import TFLiteExportCommand

optimum/commands/export/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""optimum.exporters command-line interface base classes."""
1616

1717
from .. import BaseOptimumCLICommand, CommandInfo
18+
from .executorch import ExecuTorchExportCommand
1819
from .onnx import ONNXExportCommand
1920
from .tflite import TFLiteExportCommand
2021

@@ -25,6 +26,11 @@ class ExportCommand(BaseOptimumCLICommand):
2526
help="Export PyTorch and TensorFlow models to several format.",
2627
)
2728
SUBCOMMANDS = (
29+
CommandInfo(
30+
name="executorch",
31+
help="Export PyTorch model to ExecuTorch.",
32+
subcommand_class=ExecuTorchExportCommand,
33+
),
2834
CommandInfo(
2935
name="onnx",
3036
help="Export PyTorch and TensorFlow to ONNX.",
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Defines the command line for the export with ExecuTorch."""
2+
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING
5+
6+
from ...exporters import TasksManager
7+
from ..base import BaseOptimumCLICommand
8+
9+
10+
if TYPE_CHECKING:
11+
from argparse import ArgumentParser
12+
13+
14+
def parse_args_executorch(parser):
15+
required_group = parser.add_argument_group("Required arguments")
16+
required_group.add_argument(
17+
"-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
18+
)
19+
required_group.add_argument(
20+
"--output_dir", type=Path, help="Path indicating the directory where to store the generated ExecuTorch model."
21+
)
22+
23+
optional_group = parser.add_argument_group("Optional arguments")
24+
optional_group.add_argument(
25+
"--task",
26+
default="auto",
27+
help=(
28+
"The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
29+
f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
30+
),
31+
)
32+
optional_group.add_argument(
33+
"--recipe",
34+
type=str,
35+
default="xnnpack",
36+
help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".',
37+
)
38+
39+
40+
class ExecuTorchExportCommand(BaseOptimumCLICommand):
41+
@staticmethod
42+
def parse_args(parser: "ArgumentParser"):
43+
return parse_args_executorch(parser)
44+
45+
def run(self):
46+
from ...exporters.executorch import main_export
47+
48+
main_export(
49+
model_name_or_path=self.args.model,
50+
task=self.args.task,
51+
recipe=self.args.recipe,
52+
output_dir=self.args.output_dir,
53+
)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import TYPE_CHECKING
2+
from transformers.utils import _LazyModule
3+
4+
5+
_import_structure = {
6+
"modeling_executorch": [
7+
"ExecuTorchModelForCausalLM",
8+
],
9+
}
10+
11+
if TYPE_CHECKING:
12+
from .modeling_executorch import ExecuTorchModelForCausalLM
13+
else:
14+
import sys
15+
16+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers."""
2+
3+
import logging
4+
import os
5+
import warnings
6+
from pathlib import Path
7+
from tempfile import TemporaryDirectory
8+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
9+
10+
import torch
11+
from executorch.extension.pybindings.portable_lib import _load_for_executorch
12+
from huggingface_hub import hf_hub_download
13+
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
14+
from huggingface_hub.utils import EntryNotFoundError
15+
from transformers import (
16+
AutoConfig,
17+
AutoModel,
18+
GenerationMixin,
19+
AutoModelForCausalLM,
20+
GenerationConfig,
21+
PretrainedConfig,
22+
)
23+
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
24+
from transformers.modeling_outputs import (
25+
BaseModelOutput,
26+
CausalLMOutput,
27+
CausalLMOutputWithPast,
28+
ModelOutput,
29+
)
30+
31+
from ..exporters import TasksManager
32+
from ..exporters.executorch import main_export
33+
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
34+
35+
if TYPE_CHECKING:
36+
from transformers import PretrainedConfig
37+
38+
39+
logger = logging.getLogger(__name__)
40+
41+
42+
class ExecuTorchModelForCausalLM(OptimizedModel):
43+
"""
44+
ExecuTorch model with a causal language modeling head for ExecuTorch Runtime inference.
45+
"""
46+
47+
auto_model_class = AutoModelForCausalLM
48+
49+
def __init__(
50+
self,
51+
model: "ExecuTorchModule",
52+
config: "PretrainedConfig",
53+
):
54+
super().__init__(model, config)
55+
self.et_model = model
56+
logger.debug(f"Load all static methods: {self.et_model.method_names()}")
57+
self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0]
58+
self.max_seq_len = self.et_model.run_method("get_max_seq_len")[0]
59+
self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0]
60+
self.dtype = self.et_model.run_method("get_dtype")[0]
61+
self.bos_token_id = self.et_model.run_method("get_bos_id")[0]
62+
self.eos_token_id = self.et_model.run_method("get_eos_id")[0]
63+
self.vocab_size = self.et_model.run_method("get_vocab_size")[0]
64+
65+
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
66+
return self.et_model.forward((input_ids, cache_position))[0]
67+
68+
@classmethod
69+
def from_pretrained(
70+
cls,
71+
model_name_or_path: Union[str, Path],
72+
task: str,
73+
recipe: str,
74+
export: bool = False,
75+
config: "PretrainedConfig" = None,
76+
use_auth_token: Optional[Union[bool, str]] = None,
77+
token: Optional[Union[bool, str]] = None,
78+
revision: Optional[str] = None,
79+
force_download: bool = False,
80+
cache_dir: str = HUGGINGFACE_HUB_CACHE,
81+
subfolder: str = "",
82+
local_files_only: bool = False,
83+
) -> "ExecuTorchModelForCausalLM":
84+
if use_auth_token is not None:
85+
warnings.warn(
86+
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
87+
FutureWarning,
88+
)
89+
if token is not None:
90+
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
91+
token = use_auth_token
92+
93+
if export:
94+
return cls._export(
95+
model_id=model_name_or_path,
96+
task=task,
97+
recipe=recipe,
98+
config=config,
99+
)
100+
else:
101+
return cls._from_pretrained(
102+
model_dir_path=model_name_or_path,
103+
task=task,
104+
recipe=recipe,
105+
config=config,
106+
)
107+
108+
@classmethod
109+
def _from_pretrained(
110+
cls,
111+
model_dir_path: Union[str, Path],
112+
task: str,
113+
recipe: str,
114+
config: PretrainedConfig,
115+
use_auth_token: Optional[Union[bool, str]] = None,
116+
token: Optional[Union[bool, str]] = None,
117+
revision: Optional[str] = None,
118+
force_download: bool = False,
119+
cache_dir: str = HUGGINGFACE_HUB_CACHE,
120+
subfolder: str = "",
121+
local_files_only: bool = False,
122+
**kwargs,
123+
) -> "ExecuTorchModelForCausalLM":
124+
"""Load a pre-trained model from a local directory."""
125+
full_path = os.path.join(f"{model_dir_path}", "model.pte")
126+
model = _load_for_executorch(full_path)
127+
logging.debug(f"{model.method_meta('forward')}")
128+
return cls(
129+
model=model,
130+
config=config,
131+
)
132+
133+
def _save_pretrained(self, save_directory):
134+
"""
135+
Saves a model weights into a directory, so that it can be re-loaded using the
136+
[`from_pretrained`] class method.
137+
"""
138+
raise NotImplementedError
139+
140+
@classmethod
141+
def _export(
142+
cls,
143+
model_id: str,
144+
task: str,
145+
recipe: str,
146+
config: PretrainedConfig,
147+
use_auth_token: Optional[Union[bool, str]] = None,
148+
token: Optional[Union[bool, str]] = None,
149+
revision: Optional[str] = None,
150+
force_download: bool = False,
151+
cache_dir: str = HUGGINGFACE_HUB_CACHE,
152+
subfolder: str = "",
153+
local_files_only: bool = False,
154+
trust_remote_code: bool = False,
155+
):
156+
if use_auth_token is not None:
157+
warnings.warn(
158+
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
159+
FutureWarning,
160+
)
161+
if token is not None:
162+
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
163+
token = use_auth_token
164+
165+
save_dir = TemporaryDirectory()
166+
save_dir_path = Path(save_dir.name)
167+
168+
# Export to ExecuTorch and save the pte file to the temporary directory
169+
main_export(
170+
model_name_or_path=model_id,
171+
output_dir=save_dir_path,
172+
task=task,
173+
recipe=recipe,
174+
subfolder=subfolder,
175+
revision=revision,
176+
cache_dir=cache_dir,
177+
token=token,
178+
local_files_only=local_files_only,
179+
force_download=force_download,
180+
trust_remote_code=trust_remote_code,
181+
)
182+
183+
return cls._from_pretrained(
184+
model_dir_path=save_dir_path,
185+
task=task,
186+
recipe=recipe,
187+
config=config,
188+
use_auth_token=use_auth_token,
189+
subfolder=subfolder,
190+
revision=revision,
191+
cache_dir=cache_dir,
192+
token=token,
193+
local_files_only=local_files_only,
194+
force_download=force_download,
195+
)
196+
197+
def generate(
198+
self,
199+
prompt_tokens: List[int],
200+
echo: bool = False,
201+
pos_base: int = 0,
202+
max_seq_len: int = 256,
203+
) -> List[int]:
204+
self.device = torch.device("cpu")
205+
max_seq_len = min(max_seq_len, self.max_seq_len)
206+
generated_tokens = []
207+
208+
# prefill
209+
for i, prompt_token in enumerate(prompt_tokens):
210+
logits = self.forward(
211+
input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
212+
cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
213+
)
214+
215+
next_token = torch.argmax(logits, dim=-1).item()
216+
generated_tokens = prompt_tokens + [next_token]
217+
218+
while len(generated_tokens) < max_seq_len:
219+
logits = self.forward(
220+
input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
221+
cache_position=torch.tensor(
222+
[pos_base + len(generated_tokens) - 1],
223+
dtype=torch.long,
224+
device=self.device,
225+
),
226+
)
227+
next_token = torch.argmax(logits, dim=-1).item()
228+
generated_tokens.append(next_token)
229+
if next_token == self.eos_token_id:
230+
break
231+
232+
return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]
233+
234+
def text_generation(
235+
self,
236+
tokenizer: "PreTrainedTokenizer",
237+
prompt: str,
238+
max_seq_len: int = 256,
239+
echo: bool = True,
240+
):
241+
"""
242+
Perform text completion for a prompt using the language model.
243+
244+
Args:
245+
prompt (str): Text prompt for completion.
246+
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
247+
248+
Returns:
249+
Generated texts.
250+
251+
Note:
252+
This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
253+
"""
254+
self.tokenizer = tokenizer
255+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
256+
raise ValueError(
257+
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
258+
)
259+
if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
260+
raise ValueError(
261+
f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
262+
)
263+
264+
prompt_tokens = self.tokenizer.encode(prompt)
265+
generated_tokens = self.generate(
266+
prompt_tokens=prompt_tokens,
267+
echo=echo,
268+
max_seq_len=max_seq_len,
269+
)
270+
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

0 commit comments

Comments
 (0)