
Commit 8c9ad0e

add other scripts

1 parent 9e7db0d commit 8c9ad0e

File tree

7 files changed: +455 -2 lines changed

mlperf/accuracy_run.sh

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
#!/usr/bin/env bash
me=$(basename "$0")

BASEDIR=mlperf
API_URL=0.0.0.0:9000
USER_CONFIG=$BASEDIR/user.conf
DATA_DISK_DIR=$BASEDIR/data
TOTAL_SAMPLE_COUNT=1000
DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl

# HF model id
TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
LOADGEN_RUN_TYPE=offline-performance
# OUTPUT_LOG_ID must be defined before OUTPUT_LOG_DIR references it.
OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP}
OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}

mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR}

OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json

CACHE_LENGTH=1024
INPUT_SIZE=512
OUTPUT_SIZE=512
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/

LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
# makes subsequent runs faster
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export LIBTPU_INIT_ARGS

pushd ..
# python -m mlperf.offline_mode \
#   --model_name=mixtral \
#   --max_cache_length=$CACHE_LENGTH \
#   --max_decode_length=$OUTPUT_SIZE \
#   --context_length=$INPUT_SIZE \
#   --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
#   --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
#   --quantize_weights=1 \
#   --quantize_type=int8_per_channel \
#   --quantize_kv_cache=1 \
#   --scenario Offline \
#   --input_mode tokenized \
#   --output_mode tokenized \
#   --mlperf_conf $BASEDIR/mlperf.conf \
#   --user_conf ${USER_CONFIG} \
#   --audit_conf no_audit \
#   --total_sample_count ${TOTAL_SAMPLE_COUNT} \
#   --dataset_path ${DATASET_PATH} \
#   --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log

python -m mlperf.evaluate_accuracy \
  --checkpoint-path ${TOKENIZER_PATH} \
  --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
  --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
popd
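
Note that OUTPUT_LOG_ID above is composed from MODEL_NAME, DATASET_TYPE and LOADGEN_RUN_TIMESTAMP, which this script never sets; a caller is expected to export them first. A hypothetical invocation (the values are assumptions, not part of the commit):

# Hypothetical; adjust the values to your run.
export MODEL_NAME=mixtral
export DATASET_TYPE=full
export LOADGEN_RUN_TIMESTAMP=$(date +%Y%m%d_%H%M%S)
bash mlperf/accuracy_run.sh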

mlperf/evaluate_accuracy.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
import argparse
from transformers import AutoTokenizer
import nltk
import evaluate
import numpy as np
import pandas as pd
import json
import re

import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("evaluate_accuracy.py")


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--checkpoint-path",
      required=True,
      help="Path to Mixtral-8x7b-Instruct checkpoint",
  )
  parser.add_argument(
      "--mlperf-accuracy-file",
      required=True,
      help="path to mlperf_log_accuracy.json",
  )
  parser.add_argument(
      "--dataset-file",
      required=True,
      help="path to processed validation dataset",
  )
  parser.add_argument(
      "--n_workers",
      default=2,
      type=int,
      help="Number of workers used for the MBXP evaluation",
  )
  parser.add_argument("--verbose", action="store_true", help="verbose messages")
  parser.add_argument(
      "--dtype",
      default="int64",
      help="dtype of the accuracy log",
      choices=["int32", "int64", "float"],
  )
  args = parser.parse_args()
  return args


def get_groundtruth(processed_dataset_file):
  data = pd.read_pickle(processed_dataset_file)
  return data


# Functions for evaluating GSM8K
def find_numbers(x: str) -> list[str]:
  """Finds all numbers in a string."""
  # Search for a number, possibly negative (hyphen), with thousand separators
  # (comma), and with a decimal point (period in between digits).
  numbers = re.compile(
      r"-?[\d,]*\.?\d+",
      re.MULTILINE | re.DOTALL | re.IGNORECASE,
  ).findall(x)
  return numbers


def find_number(x: str, answer_delimiter: str = "The answer is") -> str:
  """Finds the most relevant number in a string."""
  # If the model uses the answer delimiter, select the first number following
  # that format.
  if answer_delimiter in x:
    answer = x.split(answer_delimiter)[-1]
    numbers = find_numbers(answer)
    if numbers:
      return numbers[0]

  # In general, select the last number in the string.
  numbers = find_numbers(x)
  if numbers:
    return numbers[-1]
  return ""


def maybe_remove_comma(x: str) -> str:
  # Example: 5,600 -> 5600
  return x.replace(",", "")


def try_float(x: str):
  # Catch only conversion errors rather than BaseException, so real failures
  # (e.g. KeyboardInterrupt) still propagate.
  try:
    ret = float(x)
  except (TypeError, ValueError):
    ret = None
  return ret
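
As a quick sanity check of the GSM8K helpers above, consider a made-up completion (illustrative only, not from the dataset):

# Illustrative only:
sample = "He sells 60 eggs a week. The answer is 5,600."
assert find_number(sample) == "5,600"               # first number after the delimiter
assert maybe_remove_comma(find_number(sample)) == "5600"
assert try_float("5600") == 5600.0
assert find_number("3 apples and 7 pears") == "7"   # no delimiter: last number wins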


# Functions for evaluating OpenOrca


def postprocess_text(preds, targets):
  preds = [pred.strip() for pred in preds]
  targets = [target.strip() for target in targets]

  # rougeLSum expects newline after each sentence
  preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
  targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

  return preds, targets
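
Since ROUGE-Lsum scores text line by line, postprocess_text re-joins the tokenized sentences with newlines. An illustrative call (values made up; requires the nltk punkt data downloaded in main()):

preds, targets = postprocess_text(
    ["First sentence. Second one."], ["Target one. Target two."]
)
# preds   == ["First sentence.\nSecond one."]
# targets == ["Target one.\nTarget two."]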


# Functions for MBXP


def create_mbxp_dict(row, response):
  lang, entry_point = row["id"].split("_", 1)
  return {
      "lang": lang,
      "prompt": row["input"],
      "test_code": row["gt_output"],
      "entry_point": entry_point,
      "response": response,
  }
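
The row id is expected to look like "<lang>_<entry_point>", split once from the left. A made-up example row (not from the dataset):

row = {
    "id": "python_add_two",
    "input": "def add_two(x):",
    "gt_output": "assert add_two(1) == 3",
}
create_mbxp_dict(row, "  return x + 2")
# -> {"lang": "python", "prompt": "def add_two(x):",
#     "test_code": "assert add_two(1) == 3",
#     "entry_point": "add_two", "response": "  return x + 2"}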


def main():

  args = get_args()
  dataset_path = args.dataset_file
  checkpoint_path = args.checkpoint_path
  metric = evaluate.load("rouge")
  nltk.download("punkt")

  tokenizer = AutoTokenizer.from_pretrained(
      checkpoint_path,
      model_max_length=2048,
      padding_side="left",
      use_fast=False,
  )

  data = get_groundtruth(args.dataset_file)
  query_types, gt_outputs = data["dataset"], data["gt_output"]

  target_required_GSM8K = []
  target_required_OpenOrca = []
  results_MBXP = []
  preds_token_GSM8K = []
  preds_token_OpenOrca = []
  preds_token_MBXP = []

  eval_dtype = np.int64
  if args.dtype == "int32":
    eval_dtype = np.int32
  elif args.dtype == "float":
    eval_dtype = np.float32

  with open(args.mlperf_accuracy_file, "r") as f:
    results = json.load(f)

  seen = set()
  gen_tok_len = 0
  gen_num = 0
  for pred in results:
    gen_num += 1
    qsl_idx = pred["qsl_idx"]
    if qsl_idx in seen:
      continue

    seen.add(qsl_idx)

    query_type = query_types.iloc[qsl_idx]
    if query_type == "GSM8K":
      target = gt_outputs.iloc[qsl_idx]
      target_required_GSM8K.append(target)
      pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
      gen_tok_len += len(pred)
      preds_token_GSM8K.append(pred)
    elif query_type == "OpenOrca":
      target = gt_outputs.iloc[qsl_idx]
      target_required_OpenOrca.append(target)
      pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
      preds_token_OpenOrca.append(pred)
      gen_tok_len += len(pred)
    else:
      target = data.iloc[qsl_idx]
      pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
      pred_str = tokenizer.decode(pred, skip_special_tokens=True)
      results_MBXP.append(create_mbxp_dict(target, pred_str))
      gen_tok_len += len(pred)

  # OpenOrca metric
  preds_decoded_text = tokenizer.batch_decode(
      preds_token_OpenOrca, skip_special_tokens=True
  )

  preds, targets = postprocess_text(
      preds_decoded_text, target_required_OpenOrca
  )

  if preds:
    result = metric.compute(
        predictions=preds,
        references=targets,
        use_stemmer=True,
        use_aggregator=False,
    )
    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
    prediction_lens = [len(pred) for pred in preds]
  else:
    result = {}
    prediction_lens = []

  # GSM8K metric
  preds_decoded_text = tokenizer.batch_decode(
      preds_token_GSM8K, skip_special_tokens=True
  )
  pred_nums = [
      maybe_remove_comma(find_number(pred_text.split("\nQ:")[0]))
      for pred_text in preds_decoded_text
  ]
  gsm8k_total = len(target_required_GSM8K)
  correct = 0
  for idx in range(len(target_required_GSM8K)):
    ref = try_float(target_required_GSM8K[idx])
    tgt = try_float(pred_nums[idx])
    if tgt is None:
      continue
    correct += ref == tgt

  # Avoid dividing by zero if the run contained no GSM8K samples.
  if gsm8k_total:
    result["gsm8k"] = 100.0 * correct / gsm8k_total

  # MBXP metric
  # from evaluate_mbxp import evaluate_mbxp

  # if results_MBXP:
  #   result['mbxp'] = evaluate_mbxp(results_MBXP, args.n_workers)
  # else:
  #   result['mbxp'] = 0

  result = {
      **result,
      "gen_len": np.sum(prediction_lens),
      "gen_num": gen_num,
      "gen_tok_len": gen_tok_len,
      "tokens_per_sample": round(gen_tok_len / gen_num, 1),
  }

  print("\nResults\n")
  print(result)


if __name__ == "__main__":
  main()
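
Each entry in mlperf_log_accuracy.json carries the generated token ids as a hex string; the decode step main() applies to every entry looks like this in isolation (a sketch with made-up token ids):

import numpy as np

entry = {"qsl_idx": 0, "data": np.array([1, 2, 3], dtype=np.int64).tobytes().hex()}
tokens = np.frombuffer(bytes.fromhex(entry["data"]), np.int64)
print(tokens)  # [1 2 3], ready for tokenizer.decode(...)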

mlperf/install.sh

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
#!/usr/bin/env bash

DATA_DISK_DIR=data

mkdir -p $DATA_DISK_DIR

pip install -U "huggingface_hub[cli]"
pip install \
  transformers \
  nltk==3.8.1 \
  evaluate==0.4.0 \
  absl-py==1.4.0 \
  rouge-score==0.1.2 \
  sentencepiece==0.1.99 \
  accelerate==0.21.0

# install loadgen
pip install mlperf-loadgen


pushd $DATA_DISK_DIR

# model weights
gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive
# NOTE: only uncomment the checkpoints you need, so you don't download more
# weights than necessary to your box
# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive

# Get mixtral data
wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl
mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl
wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl
mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl

# Get llama70b data
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
  processed-calibration-data.pkl
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
  processed-data.pkl
popd
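
To sanity-check the downloaded pickle, a quick inspection sketch (the column names are the ones mlperf/evaluate_accuracy.py reads: "dataset", "gt_output", "input", "id"):

import pandas as pd

data = pd.read_pickle("data/mixtral_15k_data.pkl")
print(data.columns)                    # expect dataset, gt_output, input, id, ...
print(data["dataset"].value_counts())  # mix of GSM8K / OpenOrca / MBXP samples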

mlperf/mixtral_run.sh

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ me=$(basename "$0")
 BASEDIR=mlperf
 USER_CONFIG=$BASEDIR/user.conf
 DATA_DISK_DIR=$BASEDIR/data
-TOTAL_SAMPLE_COUNT=1000
+TOTAL_SAMPLE_COUNT=900

 # HF model id
 TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
