
Commit 936b922

activation quant

1 parent e1a6068 commit 936b922

6 files changed: +49 −13 lines changed

jetstream_pt/third_party/mixtral/model.py

37 additions & 4 deletions
@@ -22,8 +22,10 @@
 from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
+from jetstream_pt import quantize, torchjax

 import jax
+import jax.numpy as jnp


 class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     else:
       return self.forward_for_short_seq_len(x, expert_indices)

+  def _int_ti_eoi_teo(self, lhs, rhs):
+    # int8 version of: F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((1,), (2,)), ((), ())),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    return result
+
+  def _int_teo_eio_tei(self, lhs, rhs):
+    # int8 version of: torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((2,), (2,)), ((1,), (0,))),
+        None,
+        jnp.bfloat16.dtype,
+    )  # dot_general puts batch dims first, so the output layout is (e, t, i)
+    return result.transpose(0, 1)
+
+
   def forward_for_short_seq_len(
       self, x: Tensor, expert_indices: Tensor
   ) -> Tensor:
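The two helpers above re-express the einsums as jax.lax.dot_general calls so the int8 operands flow through an integer-friendly matmul with preferred_element_type=jnp.bfloat16 (the positional None is the precision argument). A minimal standalone sketch, with toy shapes assumed purely for illustration, showing that these dimension_numbers reproduce the einsums and why the second product comes out expert-major:

    import jax
    import jax.numpy as jnp

    # Toy shapes (assumed, not from the commit): t tokens, e experts, i/o dims.
    t, e, i, o = 4, 8, 16, 32
    kx, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
    x = jax.random.normal(kx, (t, i))      # activations, layout "ti"
    w1 = jax.random.normal(k1, (e, o, i))  # expert weights, layout "eoi"
    w2 = jax.random.normal(k2, (e, i, o))  # expert weights, layout "eio"

    # "ti,eoi->teo": contract lhs dim 1 (i) with rhs dim 2 (i); no batch dims.
    teo = jax.lax.dot_general(x, w1, (((1,), (2,)), ((), ())))
    assert teo.shape == (t, e, o)
    assert jnp.allclose(teo, jnp.einsum("ti,eoi->teo", x, w1), atol=1e-3)

    # "teo,eio->tei": contract on o, batch on e. dot_general emits batch dims
    # first, so the raw output layout is (e, t, i) and needs a transpose.
    eti = jax.lax.dot_general(teo, w2, (((2,), (2,)), ((1,), (0,))))
    assert eti.shape == (e, t, i)
    assert jnp.allclose(
        jnp.swapaxes(eti, 0, 1), jnp.einsum("teo,eio->tei", teo, w2), atol=1e-2
    )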
@@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
     # o = config.intermediate_size
     # i = config.dim
     with jax.named_scope("conditional_ff"):
-      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-      x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
+      x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+      x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
+      x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+      x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
+      x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+      x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
       expert_outs = (
-          torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+          self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
       )
       # e = 8; need to reduce to 2
       seq_indexes = torch.arange(seqlen).unsqueeze(1)
-      return expert_outs[seq_indexes, expert_indices]
+      return expert_outs[seq_indexes, expert_indices] * x1x3_scaler


 class ConditionalFeedForward(nn.Module):
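The new path works because a symmetric quantization x ≈ x_int · s_x factors out of the matmul: einsum(x, w) ≈ einsum(x_int, w_int) · s_x · s_w. That lets the int8 products above be rescaled afterwards by x_scaler, the per-channel weight scalers, and x1x3_scaler. The real quantize.quantize_tensor lives in jetstream_pt.quantize; a hedged sketch of the absmax-style quantizer it is used as here (the helper name and the third return value are assumptions):

    import torch

    def quantize_tensor_sketch(x, reduce_axis):
      """Sketch only; jetstream_pt.quantize.quantize_tensor may differ."""
      # One scale per un-reduced slice, mapping the max magnitude to 127.
      amax = torch.amax(x.abs(), dim=reduce_axis, keepdim=True)
      scale = amax.clamp(min=1e-8) / 127.0
      x_int = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
      return x_int, scale.squeeze(), None  # third slot: zero point (unused here)

    # quantize_tensor(x, (1,)) on a (t, i) activation yields one scale per
    # token, which is why the diff reshapes x_scaler to (seqlen, 1, 1) before
    # broadcasting it over the (t, e, o) matmul output.
    x = torch.randn(4, 16)
    x_int, x_scaler, _ = quantize_tensor_sketch(x, (1,))
    assert torch.allclose(x_int.float() * x_scaler.reshape(4, 1), x, atol=0.05)

SiLU is nonlinear, so x1 is materialized in bfloat16 after rescaling; the x1 * x3 product is then re-quantized over axes (1, 2), giving one scale per token across all experts and channels, and that scaler is applied once at the end, after the expert gather.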

mlperf/backend.py

1 addition & 0 deletions
@@ -283,6 +283,7 @@ def __init__(
         self.dataset.LoadSamplesToRam,
         self.dataset.UnloadSamplesFromRam,
     )
+    log.info(f'Dataset size: {self.dataset.total_sample_count} / {self.dataset.perf_count}')
     self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)

   def load_tokenizer(

mlperf/benchmark_run.sh

2 additions & 2 deletions
@@ -2,7 +2,7 @@ BASEDIR=mlperf
 API_URL=0.0.0.0:9000
 USER_CONFIG=$BASEDIR/user.conf
 DATA_DISK_DIR=$BASEDIR/data
-TOTAL_SAMPLE_COUNT=1000
+TOTAL_SAMPLE_COUNT=15000
 DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl

 # HF model id
@@ -29,4 +29,4 @@ python -m mlperf.main \
   --tokenizer-path ${TOKENIZER_PATH} \
   --log-interval 1000 \
   --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log
-popd
\ No newline at end of file
+popd

mlperf/mlperf.conf

1 addition & 1 deletion
@@ -88,7 +88,7 @@ gptj.Offline.min_query_count = 13368
 rnnt.Offline.min_query_count = 2513
 3d-unet.Offline.min_query_count = 43
 stable-diffusion-xl.Offline.min_query_count = 5000
-llama2-70b.Offline.min_query_count = 1000
+llama2-70b.Offline.min_query_count = 15000
 mixtral-8x7b.Offline.min_query_count = 1000

 # These fields should be defined and overridden by user.conf.

mlperf/start_server.sh

7 additions & 5 deletions
@@ -1,14 +1,16 @@
 #!/usr/bin/env bash

 CACHE_LENGTH=3072
-INPUT_SIZE=512
-OUTPUT_SIZE=512
-CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+INPUT_SIZE=2048
+OUTPUT_SIZE=1024
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized

 pushd ..
 python run_server.py \
+  --lazy_cache_update=1 \
+  --ring_buffer=0 \
   --model_name=mixtral \
-  --batch_size=128 \
+  --batch_size=256 \
   --max_cache_length=$CACHE_LENGTH \
   --max_decode_length=$OUTPUT_SIZE \
   --context_length=$INPUT_SIZE \
@@ -17,4 +19,4 @@ python run_server.py \
   --quantize_weights=1 \
   --quantize_type=int8_per_channel \
   --quantize_kv_cache=1
-popd
\ No newline at end of file
+popd

mlperf/user.conf

1 addition & 1 deletion
@@ -1,3 +1,3 @@
 mixtral-8x7b.Server.target_qps = 1.8
-mixtral-8x7b.Offline.target_qps = 4.0
+mixtral-8x7b.Offline.target_qps = 20.0