
Commit fbbba34

cccclai authored and facebook-github-bot committed
Factor out eager eval from eval_llama_lib (#3756)
Summary:
Pull Request resolved: #3756

Would like to re-use EagerEvalWrapper and eval function for quantization calibration.

ghstack-source-id: 228123244
exported-using-ghexport

Reviewed By: Jack-Khuu

Differential Revision: D57881028

fbshipit-source-id: 85292401d184283381bc34d0f16329ca4b9632f3
1 parent 82663ac commit fbbba34
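
The motivation above is to reuse these helpers to push calibration data through a model during quantization. A minimal sketch of what that reuse could look like, assuming a PT2E-style flow where `prepared_model` is an observer-instrumented module and `tokenizer` is the usual SentencePiece/Tiktoken tokenizer (neither object is defined in this commit):

    # Hedged sketch: only EagerEvalWrapper and evaluate_model come from this commit;
    # prepared_model and tokenizer are placeholders built by the export/quantization scripts.
    from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model

    calib_wrapper = EagerEvalWrapper(
        model=prepared_model,   # observer-instrumented eager model (assumption)
        tokenizer=tokenizer,    # SentencePieceTokenizer or Tiktoken instance (assumption)
        max_seq_length=128,     # keep calibration cheap
    )
    # Drive a handful of real samples through the model so observers record activation ranges.
    evaluate_model(calib_wrapper, tasks=["wikitext"], limit=4)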

File tree

examples/models/llama2/builder.py
examples/models/llama2/eval_llama_lib.py
examples/models/llama2/evaluate/__init__.py
examples/models/llama2/evaluate/eager_eval.py

4 files changed: +147 -111 lines


examples/models/llama2/builder.py

Lines changed: 8 additions & 1 deletion
@@ -15,6 +15,14 @@
 from typing import Any, Callable, List, Optional
 
 import torch
+
+try:
+    from ...portable.utils import export_to_edge, save_pte_program
+except ImportError:
+    # Workaround to bypass the different paths between the executorch pip package and a direct python call
+    # TODO: remove this try/except workaround and have a standard way to import portable.utils
+    # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `examples.portable.utils`.
+    from examples.portable.utils import export_to_edge, save_pte_program
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
     DuplicateDynamicQuantChainPass,
 )
@@ -33,7 +41,6 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.nn.attention import SDPBackend
 
-from ...portable.utils import export_to_edge, save_pte_program
 from ..model_factory import EagerModelFactory
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"

examples/models/llama2/eval_llama_lib.py

Lines changed: 2 additions & 110 deletions
@@ -9,8 +9,8 @@
 
 from typing import Optional, Union
 
-import lm_eval
 import torch
+from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
 from executorch.examples.models.llama2.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -20,11 +20,6 @@
 )
 
 from lm_eval.api.model import LM
-from lm_eval.evaluator import evaluate
-from lm_eval.models.huggingface import HFLM as eval_wrapper
-from lm_eval.tasks import get_task_dict
-
-from torch import nn
 
 from .builder import LlamaEdgeManager
 from .export_llama_lib import (
@@ -33,75 +28,6 @@
 )
 
 
-class EagerEvalWrapper(eval_wrapper):
-    """
-    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
-        max_seq_length: Optional[int] = None,
-        use_kv_cache: bool = False,
-    ):
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        super().__init__(device=device)
-        self._model = model
-        self._tokenizer = tokenizer
-        self._device = torch.device(device)
-        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
-        self._use_kv_cache = use_kv_cache
-
-    @property
-    def eot_token_id(self):
-        return self._tokenizer.eos_id
-
-    @property
-    def max_length(self):
-        return self._max_seq_length
-
-    @property
-    def max_gen_toks(self):
-        return 50
-
-    @property
-    def batch_size(self):
-        return 1
-
-    @property
-    def device(self):
-        return self._device
-
-    def tok_encode(self, string: str, **kwargs):
-        tokens = self._tokenizer.encode(string, bos=True, eos=False)
-        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
-        # encoded is a pytorch tensor, but some internal logic in the
-        # eval harness expects it to be a list instead
-        # TODO: verify this for multi-batch as well
-        encoded = encoded.tolist()
-        return encoded
-
-    def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
-        return decoded
-
-    def _model_call(self, inps):
-        if self._use_kv_cache:
-            pos_tensor = torch.arange(
-                self._max_seq_length, dtype=torch.int64, device=self.device
-            )
-
-            # Batch process the whole sequence.
-            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
-            return logits
-        else:
-            return self._model(inps)
-
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
-
-
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
@@ -165,40 +91,6 @@ def _model_call(self, inps):
         pass
 
 
-@torch.no_grad()
-def eval(
-    eval_wrapper: LM,
-    tasks: Optional[list] = None,
-    limit: Optional[int] = None,
-) -> dict:
-    """
-    Evaluates a language model on a specified task using the lm-evaluation-harness library.
-
-    Args:
-        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
-        task (str): The name of the evaluation task to perform.
-        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
-
-    Returns:
-        eval_results (dict): A dictionary of evaluation results for the specified task(s).
-    """
-
-    if tasks is None:
-        tasks = ["wikitext"]
-
-    if "hendrycks_test" in tasks:
-        tasks.remove("hendrycks_test")
-        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
-    task_dict = get_task_dict(tasks)
-
-    eval_results = evaluate(
-        eval_wrapper,
-        task_dict,
-        limit=limit,
-    )
-    return eval_results
-
-
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
@@ -307,7 +199,7 @@ def eval_llama(
     eval_wrapper = gen_eval_wrapper(model_name, args)
 
     # Evaluate the model
-    eval_results = eval(
+    eval_results = evaluate_model(
        eval_wrapper,
        args.tasks,
        args.limit,
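
The net effect on the entry point: eval_llama now delegates to the shared evaluate_model instead of a local eval helper. A condensed, hedged sketch of the resulting call path (the result-printing loop is illustrative and not part of this diff):

    # Condensed sketch of the post-refactor flow in eval_llama_lib.py (simplified).
    def eval_llama(model_name, args):
        eval_wrapper = gen_eval_wrapper(model_name, args)   # Eager or ExecuTorch wrapper built from CLI args
        eval_results = evaluate_model(eval_wrapper, args.tasks, args.limit)
        for task, res in eval_results["results"].items():   # lm-eval reports per-task metrics under "results"
            print(f"{task}: {res}")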
examples/models/llama2/evaluate/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .eager_eval import EagerEvalWrapper, evaluate_model
+
+__all__ = [
+    "evaluate_model",
+    "EagerEvalWrapper",
+]
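
The package `__init__` makes `evaluate` the public import surface, which is exactly what the updated eval_llama_lib.py above imports:

    # Public import path after this commit (same line as added in eval_llama_lib.py):
    from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model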
examples/models/llama2/evaluate/eager_eval.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional, Union
+
+import lm_eval
+import torch
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer as SentencePieceTokenizer,
+)
+
+from lm_eval.api.model import LM
+from lm_eval.evaluator import evaluate
+from lm_eval.models.huggingface import HFLM as eval_wrapper
+from lm_eval.tasks import get_task_dict
+
+from torch import nn
+
+
+class EagerEvalWrapper(eval_wrapper):
+    """
+    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
+        self._model = model
+        self._tokenizer = tokenizer
+        self._device = torch.device(device)
+        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
+
+    @property
+    def eot_token_id(self):
+        return self._tokenizer.eos_id
+
+    @property
+    def max_length(self):
+        return self._max_seq_length
+
+    @property
+    def max_gen_toks(self):
+        return 50
+
+    @property
+    def batch_size(self):
+        return 1
+
+    @property
+    def device(self):
+        return self._device
+
+    def tok_encode(self, string: str, **kwargs):
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
+        # encoded is a pytorch tensor, but some internal logic in the
+        # eval harness expects it to be a list instead
+        # TODO: verify this for multi-batch as well
+        encoded = encoded.tolist()
+        return encoded
+
+    def tok_decode(self, tokens):
+        decoded = self._tokenizer.decode(tokens)
+        return decoded
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            pos_tensor = torch.arange(
+                self._max_seq_length, dtype=torch.int64, device=self.device
+            )
+
+            # Batch process the whole sequence.
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
+@torch.no_grad()
+def evaluate_model(
+    eval_wrapper: LM,
+    tasks: Optional[list] = None,
+    limit: Optional[int] = None,
+) -> dict:
+    """
+    Evaluates a language model on a specified task using the lm-evaluation-harness library.
+
+    Args:
+        eval_wrapper (LM): An LM wrapper class compatible with lm-evaluation-harness evaluation.
+        tasks (Optional[list]): The names of the evaluation tasks to perform.
+        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
+
+    Returns:
+        eval_results (dict): A dictionary of evaluation results for the specified task(s).
+    """
+
+    if tasks is None:
+        tasks = ["wikitext"]
+
+    if "hendrycks_test" in tasks:
+        tasks.remove("hendrycks_test")
+        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
+    task_dict = get_task_dict(tasks)
+
+    eval_results = evaluate(
+        eval_wrapper,
+        task_dict,
+        limit=limit,
+    )
+    return eval_results
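
For completeness, a minimal eager-mode usage sketch of the new module; `eager_llama` and `tokenizer` are placeholders for objects built by the existing export/eval scripts and are not part of this diff:

    # Hedged usage sketch: run a small wikitext evaluation against an eager model.
    from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model

    wrapper = EagerEvalWrapper(
        model=eager_llama,        # torch.nn.Module producing logits (placeholder)
        tokenizer=tokenizer,      # SentencePieceTokenizer or Tiktoken instance (placeholder)
        max_seq_length=2048,
        use_kv_cache=False,
    )
    results = evaluate_model(wrapper, tasks=["wikitext"], limit=10)
    print(results)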
