|
| 1 | +import math |
| 2 | +import pandas as pd |
| 3 | +import dataclasses |
| 4 | +from collections import defaultdict |
| 5 | +from absl import flags, app |
| 6 | + |
| 7 | +from typing import Dict |
| 8 | + |
FLAGS = flags.FLAGS

# Fix: the flag previously had an empty help string, making --help useless.
flags.DEFINE_string(
    'dataset_path', '',
    'Path to a pickled pandas DataFrame with tok_input_len and '
    'tok_ref_output_len columns.')
| 12 | + |
@dataclasses.dataclass
class Stat:
  """Measured timing profile for one serving configuration (one bucket)."""

  # NOTE(review): not read by eval_scenario below; presumably the KV-cache
  # length this profile was measured at — confirm.
  cache_size: int
  # Requests decoded together; divisor for per-bucket decode time.
  batch_size: int
  # Seconds to prefill, keyed by power-of-two input-length bucket.
  prefill_times: Dict[int, float]
  # Seconds per decode step for a full batch.
  decode_time: float
| 19 | + |
# First serving configuration: one measured profile per bucket.  All values
# are seconds, keyed by power-of-two prefill length.
scenario1 = [
    Stat(
        cache_size=512,
        batch_size=2048,
        prefill_times={
            16: 0.016024088603444397,
            32: 0.021154335999926843,
            64: 0.02999803279999469,
            128: 0.043986773600045125,
            256: 0.07524209819985117,
            512: 0.13882793779994246,
        },
        decode_time=0.28033976474989686,
    ),
    Stat(
        cache_size=1280,
        batch_size=512,
        prefill_times={
            16: 0.016024088603444397,
            32: 0.020686019999993734,
            64: 0.02952769919993443,
            128: 0.04383329960000992,
            256: 0.07538782240008005,
            512: 0.13893127239989553,
            1024: 0.2693996697998955,
        },
        decode_time=0.11505070800001249,
    ),
    Stat(
        cache_size=3072,
        batch_size=256,
        prefill_times={
            32: 0.021193669800049976,
            64: 0.030565194799964956,
            128: 0.04334795760005363,
            256: 0.07586566419995507,
            512: 0.13899565000010625,
            1024: 0.26945373279995694,
            2048: 0.35605709000010394,
        },
        decode_time=0.06467210225014242,
    ),
]
| 46 | + |
# Second serving configuration: all three buckets share the same measured
# profile, so build three identical (but independent) Stat instances.
scenario2 = [
    Stat(
        cache_size=3072,
        batch_size=256,
        prefill_times={
            16: 0.018725800199899823,
            32: 0.02242145979980705,
            64: 0.02536285559981479,
            128: 0.034608948799723295,
            256: 0.0560826786000689,
            512: 0.10566568380017997,
            1024: 0.20719572800007882,
        },
        decode_time=0.0631,
    )
    for _ in range(3)
]
def eval_scenario(dataset, scenario):
  """Estimate serving time for `dataset` under a per-bucket `scenario`.

  Args:
    dataset: pandas DataFrame with columns `tok_input_len`,
      `tok_ref_output_len` and `bucket`, where `bucket` indexes into
      `scenario`.
    scenario: sequence of Stat-like objects exposing `prefill_times` (dict
      keyed by power-of-two input length), `batch_size` and `decode_time`.

  Returns:
    Dict with aggregate token counts and estimated prefill/decode times
    (the same numbers are also printed to stdout).
  """
  total_input_tokens = 0
  total_output_tokens = 0
  total_prefill_times = defaultdict(float)
  total_decode_times = defaultdict(float)
  output_tokens_by_bucket = defaultdict(int)
  for _, data in dataset.iterrows():
    stat = scenario[data.bucket]
    total_input_tokens += data.tok_input_len
    total_output_tokens += data.tok_ref_output_len
    # Round the input length up to the next power of two: prefill cost was
    # measured only at power-of-two lengths.
    input_len_bucket = 2**math.ceil(math.log2(data.tok_input_len))
    # Fix: removed a leftover `import pdb; pdb.set_trace()` breakpoint that
    # halted the script on any bucket-1 row landing in the 2048 length bucket.
    total_prefill_times[input_len_bucket] += stat.prefill_times[input_len_bucket]
    output_tokens_by_bucket[data.bucket] += data.tok_ref_output_len

  # Decode processes a whole batch one token per step, so a bucket needs
  # (tokens / batch_size) steps at decode_time seconds per step.
  for bucket, tokens in output_tokens_by_bucket.items():
    stat = scenario[bucket]
    total_decode_times[bucket] = tokens / stat.batch_size * stat.decode_time

  prefill_total = sum(total_prefill_times.values())
  decode_total = sum(total_decode_times.values())
  print('Total input tokens', total_input_tokens)
  print('Total output tokens', total_output_tokens)
  print('Input / output', total_input_tokens / total_output_tokens)
  print('Prefill times', total_prefill_times)
  print('pref throughput', total_input_tokens / prefill_total)
  print('decode times', total_decode_times)
  print('decode throughput', total_output_tokens / decode_total)
  print('overall throughput',
        total_output_tokens / (decode_total + prefill_total))
  print('prefill total time', prefill_total)
  print('decode total time', decode_total)
  return {
      'total_input_tokens': total_input_tokens,
      'total_output_tokens': total_output_tokens,
      'prefill_times': dict(total_prefill_times),
      'decode_times': dict(total_decode_times),
      'prefill_total': prefill_total,
      'decode_total': decode_total,
  }
| 103 | + |
| 104 | + |
def main(argv):
  """Entry point: load the pickled dataset, tag rows with a latency bucket,
  then report estimated throughput for both scenarios."""
  dataset = pd.read_pickle(FLAGS.dataset_path)
  combined_len = dataset.tok_input_len + dataset.tok_ref_output_len
  # Each satisfied predicate bumps a row up one bucket (0, 1 or 2).
  over_512 = combined_len > 512
  long_request = (combined_len > 1280) | (dataset.tok_input_len > 1024)
  dataset.insert(2, 'bucket', 0 + over_512 + long_request)
  eval_scenario(dataset, scenario1)
  print('======== scenario 2 ========')
  eval_scenario(dataset, scenario2)
| 113 | + |
if __name__ == '__main__':
  # absl parses the command-line flags before dispatching to main().
  app.run(main)
| 116 | + |
| 117 | + |
0 commit comments