
Commit b00be7f

act quant for conditional ffn
init params, add other scripts, debug accuracy
1 parent b7a2310 commit b00be7f

14 files changed: +1435 -38 lines

benchmarks/mixtral_offline.sh

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@ BATCH_SIZE=$2
 INPUT_SIZE=1024
 OUTPUT_SIZE=1024
 CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
+export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization"
 
 pushd ..
 python -m benchmarks.run_offline \
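
Note: JAX_COMPILATION_CACHE_DIR points JAX's persistent compilation cache at a local directory so repeated benchmark runs can reuse compiled programs, and the XLA_FLAGS line disables the rematerialization HLO pass. A minimal in-process equivalent, assuming a JAX version that exposes the jax_compilation_cache_dir config option (the path below is just the one used above):

# Hedged sketch: enable the persistent compilation cache from Python instead of
# the shell. XLA_FLAGS must be set before the first computation is compiled.
import os
os.environ.setdefault("XLA_FLAGS", "--xla_disable_hlo_passes=rematerialization")

import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache2")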

benchmarks/offline_benchmark.py

Lines changed: 11 additions & 31 deletions
@@ -21,48 +21,28 @@ class Stat:
     Stat(
         cache_size = 512,
         batch_size = 2048,
-        prefill_times = {
-            16: 0.016024088603444397,
-            32: 0.021154335999926843,
-            64: 0.02999803279999469,
-            128: 0.043986773600045125, 256: 0.07524209819985117, 512: 0.13882793779994246},
-        decode_time = 0.28033976474989686
+        prefill_times = {16: 0.02084908019969589, 32: 0.024125573800120037, 64: 0.02697298339990084, 128: 0.03641403259971412, 256: 0.05809259879970341, 512: 0.10703752639965387},
+        decode_time = 0.359
+        # decode_time = 0.28
     ),
     Stat(
         cache_size = 1280,
         batch_size = 512,
-        prefill_times = {
-            16: 0.016024088603444397,
-            32: 0.020686019999993734, 64: 0.02952769919993443, 128: 0.04383329960000992, 256: 0.07538782240008005, 512: 0.13893127239989553, 1024: 0.2693996697998955},
-        decode_time=0.11505070800001249,
+        prefill_times={16: 0.02070321020000847, 32: 0.02408570580009837, 64: 0.02650543759955326, 128: 0.036217428799864136, 256: 0.057748028799687746, 512: 0.10604073840004276, 1024: 0.20993155719988862},
+        decode_time=0.094,
     ),
     Stat(
         cache_size = 3072,
         batch_size = 256,
-        prefill_times = {32: 0.021193669800049976, 64: 0.030565194799964956, 128: 0.04334795760005363, 256: 0.07586566419995507, 512: 0.13899565000010625, 1024: 0.26945373279995694, 2048: 0.35605709000010394},
-        decode_time = 0.06467210225014242,
-    )
+        prefill_times={16: 0.020371186199918158, 32: 0.024281639599939807, 64: 0.02710893359981128, 128: 0.03605372060046648, 256: 0.0574128626001766, 512: 0.10610043820051943, 1024: 0.2097496903996216, 2048: 0.4301163775999157},
+        decode_time = 0.0552,
+    ),
 ]
 
 scenario2 = [
-    Stat(
-        cache_size = 3072,
-        batch_size = 256,
-        prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882},
-        decode_time = 0.0631,
-    ),
-    Stat(
-        cache_size = 3072,
-        batch_size = 256,
-        prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882},
-        decode_time = 0.0631,
-    ),
-    Stat(
-        cache_size = 3072,
-        batch_size = 256,
-        prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882},
-        decode_time = 0.0631,
-    )
+    scenario1[2],
+    scenario1[2],
+    scenario1[2]
 ]
 def eval_scenario(dataset, scenario):
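
As a rough illustration of how these Stat measurements can be combined (a hedged back-of-the-envelope model, not the repo's eval_scenario; the request lengths in the comment are made up): prefill cost is looked up by rounding the prompt length up to the next measured bucket, and per-token decode cost is amortized over the batch.

# Hypothetical helper, for illustration only.
import bisect

def estimate_seconds(prefill_times, decode_time, batch_size, prompt_lens, output_lens):
  buckets = sorted(prefill_times)
  total = 0.0
  for plen, olen in zip(prompt_lens, output_lens):
    # Round the prompt length up to the next measured prefill bucket.
    bucket = buckets[min(bisect.bisect_left(buckets, plen), len(buckets) - 1)]
    total += prefill_times[bucket]
    # One decode step advances every active slot, so amortize over the batch.
    total += decode_time * olen / batch_size
  return total

# e.g. estimate_seconds(scenario1[2].prefill_times, scenario1[2].decode_time,
#                       scenario1[2].batch_size, [750, 120], [128, 64])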

benchmarks/run_offline.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,8 @@ def main(argv):
   decode_state = engine.init_decode_state()
   profiler_started = False
   for batch, _ in MAXTEXT_PREFILL.items():
+    if batch > FLAGS.max_cache_length:
+      continue
     runtime, decode_state, profiler_started = run_prefill_time(
         engine, params, decode_state, batch, profiler_started
     )

jetstream_pt/config.py

Lines changed: 6 additions & 3 deletions
@@ -157,11 +157,14 @@ def create_quantization_config_from_flags():
   return config
 
 
-def create_engine_from_config_flags():
+def create_engine_from_config_flags(batch=None, cache_len=None):
   """create a pytorch engine from cmd flag"""
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
 
+  batch = batch or FLAGS.batch_size
+  cache_len = cache_len or FLAGS.max_cache_length
+
   devices = jax.devices()
   start = time.perf_counter()
 
@@ -196,9 +199,9 @@ def create_engine_from_config_flags():
       bf16_enable=FLAGS.bf16_enable,
       param_size=FLAGS.size,
       context_length=FLAGS.context_length,
-      batch_size=FLAGS.batch_size,
+      batch_size=batch,
       quant_config=quant_config,
-      max_cache_length=FLAGS.max_cache_length,
+      max_cache_length=cache_len,
       max_decode_length=FLAGS.max_decode_length,
       sharding_config=sharding_file_name,
       shard_on_batch=FLAGS.shard_on_batch,
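
With the new optional arguments, callers can override the batch size and cache length programmatically while everything else still comes from flags. A minimal usage sketch (the 96/2048 values are illustrative, and absl flags must already be parsed, e.g. inside app.run(main)):

# Hedged sketch: build two engines with different batch/cache shapes without
# re-parsing flags; None falls back to FLAGS.batch_size / FLAGS.max_cache_length.
from jetstream_pt import config

engine_default = config.create_engine_from_config_flags()
engine_alt = config.create_engine_from_config_flags(batch=96, cache_len=2048)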

jetstream_pt/offline_inference.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
+from typing import Callable
+import dataclasses
+from collections import defaultdict
+import jax
+from jax import numpy as jnp
+import numpy as np
+
+from jetstream.engine import engine_api
+
+import logging
+
+log = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class InputData:
+  id: str
+  tokens: jax.Array
+  true_length: int
+
+
+class OfflineInference:
+
+  def __init__(self, engine: engine_api.Engine, params=None):
+    self.engine = engine
+    self.decode_state = None
+    if params is None:
+      params = engine.load_params()
+    self.params = params
+
+    self.batch_size = engine.env.batch_size
+    self.max_decode_length = engine.max_decode_length
+    metadata = engine.get_tokenizer()
+    self.tokenizer = engine.build_tokenizer(metadata)
+    self.dummy = False
+
+    self._cached_pref = {}
+    self._cached_generate = None
+
+  def init_decode_state(self):
+    if self.decode_state is None:
+      self.decode_state = self.engine.init_decode_state()
+
+  def warmup(self, max_length=2048):
+    self.init_decode_state()
+    interesting_buckets = [
+        32,
+        64,
+        128,
+        256,
+        512,
+        1024,
+        2048,
+        4096,
+    ]
+    for length in interesting_buckets:
+      if length > max_length:
+        break
+      log.info(f"Compiling prefill: {length}")
+      input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
+      self._cached_pref[length] = (
+          jax.jit(self._prefill_insert, donate_argnums=(4,))
+          .lower(
+              self.params,
+              tokens=input_data,
+              slot=0,
+              true_length=length - 1,
+              decode_state=self.decode_state)
+          .compile()
+      )
+
+    log.info("Compiling decode")
+    self._cached_generate = (
+        jax.jit(self.engine.generate, donate_argnums=(1,))
+        .lower(self.params, self.decode_state)
+        .compile()
+    )
+
+  def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
+    """Prefills a prompt, inserts it at slot, and returns (first_token, decode_state)."""
+    prefill_result, first_token = self.engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = self.engine.insert(prefill_result, decode_state, slot=slot)
+    return first_token, decode_state
+
+  def batch_inference_with_callback(
+      self,
+      data: InputData,
+      emit_first_token: Callable[[str, int], bool],
+      emit_token: Callable[[str, int], bool],
+  ):
+    """The callbacks take (id, token) and are called once per output token.
+
+    Returning True from a callback marks that request as finished.
+    """
+
+    def prefill(slot, tokens, true_length):
+      nonlocal self
+      if self.dummy:
+        log.debug("dummy prefill")
+        return 123
+
+      prefill_fn = self._prefill_insert
+      if (cached := self._cached_pref.get(len(tokens))) is not None:
+        prefill_fn = cached
+
+      first_token, self.decode_state = prefill_fn(
+          self.params, tokens=tokens, slot=slot,
+          true_length=true_length, decode_state=self.decode_state
+      )
+      return first_token.data[0][0].item()
+
+    empty_slots = list(range(self.batch_size))
+    slot_to_id = {}
+
+    dummy_length = 1
+
+    def decode():
+      log.debug("decode")
+      nonlocal self
+      nonlocal slot_to_id
+      nonlocal dummy_length
+      if self.dummy:
+        log.debug("Dummy generate")
+        res = engine_api.ResultTokens(
+            data=np.array([[123, 1, dummy_length]] * self.batch_size),
+            tokens_idx=(0, 0),
+            valid_idx=(0, 0),
+            length_idx=(0, 0),
+            samples_per_slot=(0, 0),
+        )
+        dummy_length += 1
+        self.decode_state, result_tokens = self.decode_state, res
+      else:
+        gen_fn = self.engine.generate
+        if self._cached_generate is not None:
+          gen_fn = self._cached_generate
+        self.decode_state, result_tokens = gen_fn(
+            self.params, self.decode_state
+        )
+
+      result_tokens = result_tokens.convert_to_numpy()
+
+      newly_empty = []
+      for slot, id_ in slot_to_id.items():
+        token, is_valid, length = result_tokens.data[slot]
+        log.debug(f"slot is {slot}, length is {length}")
+        should_finish = False
+        if is_valid:
+          should_finish = emit_token(id_, token.item())
+        if should_finish or length >= self.max_decode_length:
+          newly_empty.append(slot)
+
+      # Return slots whose requests finished to the empty pool.
+      for slot in newly_empty:
+        del slot_to_id[slot]
+        empty_slots.append(slot)
+
+    for row in data:
+      log.debug(f"empty_slots {len(empty_slots)}")
+      while not empty_slots:
+        # If slots are all full, decode until there are free slots
+        # to insert.
+        decode()
+      # Do one insert.
+      log.debug(f"prefill {row.id}")
+      slot = empty_slots.pop()
+      first_token = prefill(slot, row.tokens, row.true_length)
+      should_terminate = emit_first_token(row.id, first_token)
+      if not should_terminate:
+        slot_to_id[slot] = row.id
+      else:
+        empty_slots.append(slot)  # don't use the slot
+
+    while slot_to_id:
+      log.debug(f"slot to id {len(slot_to_id)}")
+      decode()
+
+  def batch_inference(self, data: InputData):
+    """data is a list of objects with id, tokens, and true_length"""
+    ans = defaultdict(list)
+
+    def callback(id_, token):
+      nonlocal ans
+      ans[id_].append(token)
+      return token == self.tokenizer.eos_id
+
+    self.batch_inference_with_callback(
+        data, emit_first_token=callback, emit_token=callback
+    )
+    return ans
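
A minimal driver sketch for the new class, assuming an engine built via create_engine_from_config_flags and made-up token IDs; padding prompts to a bucket length that warmup() compiled is the caller's responsibility here:

# Hedged sketch, not part of this commit: run offline batch inference over a
# few pre-tokenized prompts (absl flags must already be parsed).
import numpy as np
from jax import numpy as jnp
from jetstream_pt import config
from jetstream_pt.offline_inference import InputData, OfflineInference

engine = config.create_engine_from_config_flags()
inference = OfflineInference(engine)
inference.warmup(max_length=2048)  # pre-compiles prefill buckets and decode

def make_input(id_, token_ids, bucket=128):
  # Pad to a warmed-up bucket; true_length keeps the real prompt length so
  # prefill can ignore the padding.
  padded = np.zeros(bucket, dtype=np.int32)
  padded[: len(token_ids)] = token_ids
  return InputData(id=id_, tokens=jnp.asarray(padded), true_length=len(token_ids))

data = [make_input("req-0", [1, 2, 3, 4]), make_input("req-1", [1, 5, 6])]
results = inference.batch_inference(data)  # dict: request id -> output token ids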

jetstream_pt/third_party/mixtral/model.py

Lines changed: 37 additions & 4 deletions
@@ -22,8 +22,10 @@
 from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
+from jetstream_pt import quantize, torchjax
 
 import jax
+import jax.numpy as jnp
 
 
 class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     else:
       return self.forward_for_short_seq_len(x, expert_indices)
 
+  def _int_ti_eoi_teo(self, lhs, rhs):
+    # Quantized version of: x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((1,), (2,)), ((), ())),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    return result
+
+  def _int_teo_eio_tei(self, lhs, rhs):
+    # Quantized version of: torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((2,), (2,)), ((1,), (0,))),
+        None,
+        jnp.bfloat16.dtype,
+    )  # output comes back as (e, t, i), so transpose to (t, e, i)
+    return result.transpose(0, 1)
+
+
   def forward_for_short_seq_len(
       self, x: Tensor, expert_indices: Tensor
   ) -> Tensor:
@@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
     # o = config.imtermediate size
     # i = config.dim
     with jax.named_scope("conditional_ff"):
-      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-      x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
+      x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+      x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
+      x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+      x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
+      x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+      x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
       expert_outs = (
-          torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+          self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
      )
       # e = 8; need to reduce to 2
       seq_indexes = torch.arange(seqlen).unsqueeze(1)
-      return expert_outs[seq_indexes, expert_indices]
+      return expert_outs[seq_indexes, expert_indices] * x1x3_scaler
 
 
 class ConditionalFeedForward(nn.Module):
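
The shape of the change above (quantize the activations per token, run the contraction in low precision, then fold the activation and weight scales back in) can be shown standalone. The sketch below assumes symmetric int8 scaling and is not the repo's quantize.quantize_tensor implementation:

# Hedged sketch of per-token activation quantization around the "ti,eoi->teo"
# contraction used by the conditional FFN.
import jax
import jax.numpy as jnp

def quantize_per_token(x):
  # x: (t, i) activations; one scale per token (reduce over axis 1).
  scale = jnp.maximum(jnp.max(jnp.abs(x), axis=1, keepdims=True) / 127.0, 1e-6)
  x_int = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
  return x_int, scale

def int_ti_eoi_teo(x, w_int, w_scaler):
  # w_int: (e, o, i) int8 expert weights; w_scaler: per-output-channel scale.
  x_int, x_scale = quantize_per_token(x)
  out = jax.lax.dot_general(
      x_int, w_int,
      (((1,), (2,)), ((), ())),  # contract i with i, no batch dims -> (t, e, o)
      preferred_element_type=jnp.bfloat16,
  )
  return out * w_scaler * x_scale.reshape(-1, 1, 1)

Keeping the matmul in int8 and applying both scales afterwards is what lets the expert weights stay quantized without materializing a dequantized copy, which is the point of the forward_for_long_seq_len change.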
