diff --git a/benchmarks/mixtral_offline.sh b/benchmarks/mixtral_offline.sh
new file mode 100644
index 00000000..ea64195f
--- /dev/null
+++ b/benchmarks/mixtral_offline.sh
@@ -0,0 +1,20 @@
+CACHE_LENGTH=1024
+INPUT_SIZE=512
+OUTPUT_SIZE=1024
+BATCH_SIZE=512
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+
+pushd ..
+python -m benchmarks.run_offline \
+  --model_name=mixtral \
+  --batch_size=$BATCH_SIZE \
+  --max_cache_length=$CACHE_LENGTH \
+  --max_decode_length=$OUTPUT_SIZE \
+  --context_length=$INPUT_SIZE \
+  --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
+  --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
+  --quantize_weights=1 \
+  --quantize_type=int8_per_channel \
+  --quantize_kv_cache=1 \
+  --profiling_output=/mnt/disks/hanq/mixtral-profiles
+popd
\ No newline at end of file
diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index daeafac7..1fdc0cb7 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -32,7 +32,7 @@
 flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")
 
 
-def run_prefill_time(engine, params, decode_state, seqlen):
+def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
   """Run prefill and measure time."""
   metadata = engine.get_tokenizer()
   tokenizer = engine.build_tokenizer(metadata)
@@ -53,6 +53,10 @@ def run_prefill_time(engine, params, decode_state, seqlen):
   nums = 5
   start = time.perf_counter()
   for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
     prefill_result, _ = engine.prefill(
         params=params, padded_tokens=tokens, true_length=true_length
     )
@@ -60,8 +64,9 @@ def run_prefill_time(engine, params, decode_state, seqlen):
         prefill_result, decode_state, slot=jnp.int32(i)
     )
   jax.block_until_ready(decode_state)
+
   end = time.perf_counter()
-  return (end - start) / nums, decode_state
+  return (end - start) / nums, decode_state, profiler_started
 
 
 MAXTEXT_PREFILL = {
@@ -86,9 +91,10 @@ def main(argv):
   prefill_times = {}
   decode_state = engine.init_decode_state()
+  profiler_started = False
   for batch, _ in MAXTEXT_PREFILL.items():
-    runtime, decode_state = run_prefill_time(
-        engine, params, decode_state, batch
+    runtime, decode_state, profiler_started = run_prefill_time(
+        engine, params, decode_state, batch, profiler_started
     )
     prefill_times[batch] = runtime
@@ -103,10 +109,12 @@ def main(argv):
   profiling_output = FLAGS.profiling_output
   print("======= decode starting ===")
+
   dec_times = []
   for i in range(10):
-    if profiling_output and i == 7:
+    if profiling_output and i == 7 and not profiler_started:
       jax.profiler.start_trace(profiling_output)
+      profiler_started = True
     start = time.perf_counter()
     # pylint: disable-next=all
     decode_state, sampled_tokens = engine.generate(params, decode_state)
@@ -116,14 +124,24 @@ def main(argv):
     dec_times.append(end - start)
     print(i, "decode time", (end - start))
 
-  if profiling_output:
+  if profiler_started:
     jax.profiler.stop_trace()
 
   print("prefill ", prefill_times)
-  print("decode", sum(dec_times) / 10)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)
 
   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
-  decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size
+  decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+  largest_prefill = max(prefill_times.items())
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+  time2 = (FLAGS.batch_size *
FLAGS.max_decode_length) / ( + FLAGS.batch_size * largest_prefill[1] + + FLAGS.max_decode_length * avg_decode_times + ) + print("MAX tokens 2:", time2) sharegpt_path = FLAGS.sharegpt_path if sharegpt_path: diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 96bb4233..6d571d2c 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -16,6 +16,7 @@ def ragged_flash_attention_kernel( + layer_ref, start_ref, end_ref, line_end_ref, @@ -105,30 +106,45 @@ def run(): @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], ) def ragged_mqa( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, ragged_batch_index=None, ragged_block_index=None, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] + batch_size, time, head_dim = q.shape + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 5: + stacked = True def kv_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -136,11 +152,20 @@ def kv_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + + if stacked: + return ( + layer_ref[0], + ragged_batch_index_ref[index], + ragged_block_index_ref[index], + 0, + ) return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 def q_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -148,17 +173,32 @@ def q_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + if stacked: + return layer_ref[0], ragged_batch_index_ref[index], 0, 0 return ragged_batch_index_ref[index], 0, 0 - def scaler_index_map(b, i, *_): + def scaler_index_map(b, i, layer_ref, *_): + if stacked: + return layer_ref[0], b, 0, i return b, 0, i line_end = jnp.where(start < end, end, seq_len - 1) + if stacked: + q_bp = (None, None, time, head_dim) + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + else: + q_bp = (None, time, head_dim) + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(q_index_map, q_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(scaler_index_map, ks_bp), + pl.BlockSpec(scaler_index_map, ks_bp), ] inputs = ( start, @@ -169,15 +209,9 @@ def scaler_index_map(b, i, *_): q, k, v, + k_scaler, + v_scaler, ) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True out, m, l = pl.pallas_call( functools.partial( @@ -191,33 +225,241 @@ def scaler_index_map(b, i, *_): num_scalar_prefetch=5, in_specs=in_specs, out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, 
head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), ], grid=(batch_size, seq_len // bk), ), compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + interpret=testing, out_shape=[ q, - jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 - ), - jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 - ), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), ], )(*inputs) return out, (m[..., 0], l[..., 0]) +def ragged_mqa_kernel_reference( + layer_ref, + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for ragged attention.""" + b, i = pl.program_id(0), pl.program_id(1) + del layer_ref + + @pl.when(i == 0) + def init(): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + # length = lengths_ref[b] + # Always start from 0, left aligned + length = end_ref[b] + + @pl.when(i * bk < length) + def run(): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + + if normalize_var: + qk = qk / math.sqrt(k.shape[-1]) # Align with meta llama + # Quantized + if quantized: + qk = qk * k_scaler_ref[...] + + mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length + qk = qk + jnp.where(mask, 0.0, mask_value) + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + # Quantized + if quantized: + s_curr = s_curr * v_scaler_ref[...] + + o_curr_times_l_curr = jnp.dot(s_curr, v) + + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] 
+ beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + + @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], +) +def ragged_mqa_reference( + q: jax.Array, + k: jax.Array, + v: jax.Array, + layer, + start: jax.Array, + end: jax.Array, + ragged_batch_index=None, + ragged_block_index=None, + k_scaler: jax.Array = None, + v_scaler: jax.Array = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + batch_size, time, head_dim = q.shape + # assert end.shape == (batch_size,) + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 4: + stacked = True + + def _compute_ragged_block_indices(b, i, lengths_ref): + length = lengths_ref[b] + not_done = i * bk < length + am_last_batch = b == batch_size - 1 + # if length < bk, then it's -1, should be 0? + last_good_block = jax.lax.div(length, bk) - 1 + + # if not done, then still work on b, otherwise next batch + b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1)) + # if not done, i next = i + # if done + # if last batch, previous good block + # if not last batch, i next = 0 + i_next = jnp.where( + not_done, i, jnp.where(am_last_batch, last_good_block, 0) + ) + return b_next, i_next + + def kv_index_map(b, i, layer_ref, start_ref, end_ref, *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, i_next, 0 + return b_next, i_next, 0 + + def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, 0, i_next + return b_next, 0, i_next + + if stacked: + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + else: + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + + in_specs = [ + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # q + pl.BlockSpec(kv_index_map, kv_bp), # k + pl.BlockSpec(kv_index_map, kv_bp), # v + pl.BlockSpec(kv_scale_index_map, ks_bp), # k_scaler + pl.BlockSpec(kv_scale_index_map, ks_bp), # v_scaler + ] + + inputs = ( + jnp.array([layer]), + start, + end, + end, # line_end, not actually used + ragged_batch_index, + ragged_block_index, + q, + k, + v, + k_scaler, + v_scaler, + ) + + out, m, l = pl.pallas_call( + functools.partial( + ragged_mqa_kernel_reference, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + interpret=testing, + # debug=True, + compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial( + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "q_shard_axis", + "kv_shard_axis", + "testing", + ], ) def ragged_mha( q: 
jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, ragged_batch_index: jax.Array, @@ -227,7 +469,9 @@ def ragged_mha( bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, - shard_axis: int = 1, + q_shard_axis: int = 0, + kv_shard_axis: int = 0, + testing: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi head attention. Args: @@ -251,35 +495,66 @@ def ragged_mha( softmax denominator ([batch_size, num_heads, compute_dim, 1]). """ mask_value = DEFAULT_MASK_VALUE + bk = min(bk, k.shape[-2]) + bq, hq, tq, dq = q.shape + hkv = k.shape[-3] + tk = k.shape[-2] + + assert k.shape[-1] == q.shape[-1] + assert k.shape[-4] == q.shape[-4] + + rep = hq // hkv + if rep > 1: + q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) + stacked = k.ndim == 5 + + replicated_in_axes = 7 if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (ragged_batch_index, ragged_block_index) + quantized = False + if k.ndim == 5: + kv_scale_shape = (k.shape[0], bq, 1, tk) + else: + kv_scale_shape = (bq, 1, tk) + k_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) + v_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) else: - replicated_in_axes = 6 - replicated_inputs = ( - jnp.squeeze(k_scaler, -1), - jnp.squeeze(v_scaler, -1), - ragged_batch_index, - ragged_block_index, - ) + quantized = True + k_scale = jnp.squeeze(k_scaler, -1) + v_scale = jnp.squeeze(v_scaler, -1) + + if stacked: + assert k_scale.shape == (k.shape[0], bq, 1, tk) + else: + assert k_scale.shape == (bq, 1, tk) + + replicated_inputs = ( + ragged_batch_index, + ragged_block_index, + k_scale, + v_scale, + ) + # New cache has t=1 with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - ragged_mqa, + # ragged_mqa, + ragged_mqa_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, + testing=testing, + quantized=quantized, # out_dtype=out_dtype, ), in_axes=( - shard_axis, - shard_axis, - shard_axis, + q_shard_axis, + kv_shard_axis, + kv_shard_axis, *([None] * replicated_in_axes), ), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) + out_axes=q_shard_axis, + )(q, k, v, layer, start, end, *replicated_inputs) return out, (m, l) @@ -310,15 +585,77 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): return output +def flash_attention( + xq, + keys, + values, + layer, + k_scaler=None, + v_scaler=None, + mask=None, + normalize_var=True, +): + """Flash attention kernel.""" + if keys.ndim == 5: + keys = keys[layer] + values = values[layer] + k_scaler = k_scaler[layer] if k_scaler is not None else None + v_scaler = v_scaler[layer] if v_scaler is not None else None + + logits = torch.einsum( + "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) + ) + + if normalize_var: + logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama + # Quantized + if k_scaler is not None: + logits = logits * k_scaler.reshape( + k_scaler.shape[-4], 1, 1, k_scaler.shape[-2] + ) + + # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] + if mask is not None: + # logits = logits + jnp.where(mask, 0.0, DEFAULT_MASK_VALUE)[:, None] + logits = logits + mask + + logits_max, _ = torch.max(logits, axis=-1, keepdim=True) + unnormalized = torch.exp(logits - logits_max) + # Quantized, should not put here, otherwise sum will have this too, which cancels with denominator + # unnormalized = unnormalized * v_scaler + + denominator = unnormalized.sum(axis=-1, 
keepdim=True) + if v_scaler is not None: + unnormalized = unnormalized * v_scaler.reshape( + v_scaler.shape[-4], 1, 1, v_scaler.shape[-2] + ) + o = ( + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) + / denominator + ) + + return o, (logits_max, denominator) + + class RaggedAttentionKernel: """Ragged attention kernel.""" - def __init__(self, env, input_specs, output_specs, sharding_axis): + def __init__( + self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis + ): self.binded_ragged_mha = functools.partial( - ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ragged_mha, + bk=env.block_size, + q_shard_axis=q_shard_axis, + kv_shard_axis=kv_shard_axis, + testing=env.testing, ) self.binded_ragged_mha = shard_map( - ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + self.binded_ragged_mha, + env.mesh, + input_specs, + output_specs, + check_rep=False, ) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) @@ -327,6 +664,7 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, @@ -338,6 +676,7 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 13789f91..76f44120 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -14,7 +14,10 @@ import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map import torch +import torch_xla2 + from jetstream_pt import torchjax @@ -38,19 +41,20 @@ def update(self, key, value): class KVCachePrefill: """Prefill kv cache""" - def __init__(self, kv_quantize=False): + def __init__(self, kv_quantize=False, stacked=False): self.kv_quantize = kv_quantize self.cache_k = None self.cache_v = None + self.stacked = stacked - def update(self, key, value): + def update(self, key, value, layer_id): """This cache just remembers the stuff.""" self.cache_k = key self.cache_v = value if self.kv_quantize: # pretend to be quantized bsz, _, seq, _ = key.shape ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)) - return key, value, ones, ones + return key, value, None, None, ones, ones, None, None return key, value @@ -58,6 +62,11 @@ def state(self): """Get prefill cache state""" return self.cache_k, self.cache_v + # Placeholder, to match with GenerateCache + def finalize(self): + """Finalize the cache operation and updates the cache.""" + return + # pylint: disable-next=all def KVCachePrefill_flatten(cache): @@ -80,57 +89,225 @@ def KVCachePrefill_unflatten(auxdata, data): ) -# Refactor out cache management -# Easier to test for quantized kv cache class KVCacheGenerate: """Kvache generator without quantization""" + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. def __init__( self, cache_k: torch.Tensor, # previous cache cache_v: torch.Tensor, # previous cache - position: int, # position to store the cache + position: int | torch.Tensor, # position to store the cache sharding, env=None, ): super().__init__() self.cache_k = cache_k self.cache_v = cache_v - self.pos = position + self.input_pos = position self.sharding = sharding self.env = env - def update(self, key, value): + self.new_ks = None + self.new_vs = None + self.env = env + # Keep this one it's used in the specific model code. 
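+    # Cache layout flags used below:
+    #   generate_cache_stacked: the decode KV cache for all layers lives in one
+    #   tensor shaped (layer, batch, heads, seq, dim) instead of one cache per layer.
+    #   new_cache_stacked: the current step's K/V entries are likewise kept in a
+    #   single (layer, batch, heads, 1, dim) buffer.
+    #   lazy_cache_update: update() only records the new K/V; the scatter into the
+    #   big cache is deferred to finalize() after all layers have run.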
+ self.stacked = env.generate_cache_stacked + self.batch = jnp.arange(self.env.batch_size) + # The other way is to store the list and loop over to insert in finalize() + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + layer, batch, heads, _, dim = self.cache_k.shape + new_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_dim, dtype=self.env.default_type), + jnp.zeros(new_dim, dtype=self.env.default_type), + ) + ) + else: + self.new_ks, self.new_vs = [], [] + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + none_pspec = self.env.partition_by_axis() + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) + out_specs = (cache_pspec, cache_pspec) + self.update_single_cache_line = jax.jit( + shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + ) + + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): + """The shard map version of single cache line update.""" + b = cache_k.shape[-4] + for bb, pp in enumerate(pos.reshape(b)): + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + # We are not handling generate_cache_stacked=True new_cache_stacked=False here + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + return cache_k, cache_v + + def finalize(self): + """Finalize the cache operation and updates the cache.""" + if not self.env.lazy_cache_update: + return + + if self.env.ring_buffer: + # Assume no cache stack for ring buffer + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) + ) + else: + if self.env.generate_cache_stacked: + _, b, head, _, dim = self.cache_k.shape + if self.env.new_cache_stacked: + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.input_pos, + ) + else: + for i in range(self.env.num_layers): + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_ks[i].jax().reshape(b, head, dim)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_vs[i].jax().reshape(b, head, dim)) + ) + else: + # Try to use shard_map to get rid of the data copy + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.input_pos, + ) + + def update(self, key, value, layer_id: int): """Update kv cache""" keyj, valuej = torchjax.to_torch((key, value)) + if 
self.env.lazy_cache_update: + if self.env.new_cache_stacked: + assert ( + self.env.generate_cache_stacked + ), "When new cache stacked, must have generate_cache_stacked!" + self.new_ks[layer_id, ...] = keyj + self.new_vs[layer_id, ...] = valuej + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Generate cache stacked, but new cache unstacked + if self.env.generate_cache_stacked: + self.new_ks.append(keyj) + self.new_vs.append(valuej) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # all cache unstacked + self.new_ks = keyj + self.new_vs = valuej + return self.cache_k, self.cache_v + if self.env.ring_buffer: + assert ( + not self.env.new_cache_stacked and not self.env.generate_cache_stacked + ), "Ring buffer doesn't support stacked cache." # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(keyj) + ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej) - else: - batch = jnp.arange(self.env.batch_size) + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(valuej) + ) + return self.cache_k, self.cache_v + + # Non lazy cache update, non ring buffer, generate cache stacked + if self.env.generate_cache_stacked: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set( - keyj.squeeze(2) + self.cache_k._elem = ( + self.cache_k.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set( - valuej.squeeze(2) + self.cache_v._elem = ( + self.cache_v.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) ) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Non lazy cache update, non ring buffer, generate cache non stacked + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) return self.cache_k, self.cache_v def state(self): """Get kv cache state""" - # pylint: disable-next=all return self.cache_k.jax(), self.cache_v.jax() @classmethod - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" - default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32 - k = jnp.zeros(shape, device=device, dtype=default_dtype) - v = jnp.zeros(shape, device=device, dtype=default_dtype) + default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 + in_shape = shape + if env.testing: + key = jax.random.key(env.testing_seed) + k_key, v_key = jax.random.split(key) + k = jax.random.uniform(k_key, shape=in_shape, dtype=default_dtype) + v = jax.random.uniform(v_key, shape=in_shape, dtype=default_dtype) + else: + k = jnp.zeros(in_shape, device=device, dtype=default_dtype) + v = jnp.zeros(in_shape, device=device, dtype=default_dtype) k, v = torchjax.to_torch((k, v)) return cls(k, v, 0, device, env=env) @@ -159,7 +336,8 @@ def KVCacheGenerate_unflatten(auxdata, data): class Int8KVCacheGenerate: """Int8 quantized kvache with scalers""" - # pylint: disable-next=all + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. 
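+  # K/V values are stored as int8 next to bfloat16 scalers; quantize() reduces
+  # amax over the head and head_dim axes and divides by 127, so dequantization
+  # is simply int8_value * scaler per (batch, position).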
def __init__( self, cache_k, @@ -175,9 +353,153 @@ def __init__( self.cache_v = cache_v self.k_scaler = cache_k_scaler self.v_scaler = cache_v_scaler + self.new_ks = None + self.new_vs = None + self.new_k_scaler = None + self.new_v_scaler = None + + self.batch = jnp.arange(env.batch_size) self.input_pos = input_pos self.sharding = sharding self.env = env + self.stacked = env.generate_cache_stacked + + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + layer, batch, heads, _, dim = self.cache_k.shape + new_kv_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_kv_dim, dtype=jnp.int8), + jnp.zeros(new_kv_dim, dtype=jnp.int8), + ) + ) + if self.env.new_cache_stacked: + new_scale_dim = (layer, batch, 1, 1, 1) + self.new_k_scaler, self.new_v_scaler = torchjax.to_torch( + ( + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + ) + ) + else: + self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = ( + [], + [], + [], + [], + ) + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + new_cache_pspec = ( + self.env.partition_by_axis(2) + if self.env.new_cache_stacked + else self.env.partition_by_axis(1) + ) + none_pspec = self.env.partition_by_axis() + in_specs = ( + *([cache_pspec] * 2), + *([new_cache_pspec] * 2), + *([none_pspec] * 5), + ) + out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) + self.update_single_cache_line = shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + self.update_single_cache_line = jax.jit(self.update_single_cache_line) + + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. 
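+  # update_single_cache_line runs under shard_map, so each device only sees its
+  # shard of the head dimension. For every batch row bb it slices that row's new
+  # K/V (and scalers) and writes them into the cache at position pos[bb] with
+  # jax.lax.dynamic_update_slice.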
+ def update_single_cache_line( + self, + cache_k, + cache_v, + new_ks, + new_vs, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + pos, + ): + """The shard map version of single cache line update.""" + b = cache_k.shape[-4] + + for bb, pp in enumerate(pos.reshape(b)): + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + if self.env.generate_cache_stacked and not self.env.new_cache_stacked: + for layer in range(self.env.num_layers): + update_start_indices = (layer, bb, 0, pp, 0) + new_ks_slice = jax.lax.dynamic_slice_in_dim( + new_ks[layer], bb, 1, slice_dim + ) + new_ks_slice = jnp.expand_dims(new_ks_slice, 0) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim( + new_vs[layer], bb, 1, slice_dim + ) + new_vs_slice = jnp.expand_dims(new_vs_slice, 0) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler[layer], bb, 1, slice_dim + ) + new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler[layer], bb, 1, slice_dim + ) + new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + else: + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler, bb, 1, slice_dim + ) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler, bb, 1, slice_dim + ) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + + return cache_k, cache_v, k_scaler, v_scaler def state(self): """Get kv cache state""" @@ -189,13 +511,17 @@ def scalers(self): @classmethod # pylint: disable-next=all - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - # bf16_enable is a placeholder parameter, it's not used in Int8KVCache - kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + + if env.generate_cache_stacked: + s_shape = (shape[0], shape[1], 1, shape[3], 1) + else: + s_shape = (shape[0], 1, shape[2], 1) + kscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) + vscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) @@ -205,23 +531,126 @@ def empty(cls, shape, device, bf16_enable, env): def quantize(self, val): """Quantize value""" # val is (batch, heads, seqlen, dim) - scale = torch.amax(val.abs(), axis=(1, 3), keepdim=True) + scale = torch.amax(val.abs(), axis=(-3, -1), keepdim=True) scale = 
scale / 127 return (val / scale).to(torch.int8), scale - def update(self, xk, xv): + def update(self, xk, xv, layer_id: int): """Update kv cache""" k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) - if self.env.ring_buffer: + + if self.env.lazy_cache_update: + if self.env.new_cache_stacked: + self.new_ks[layer_id, ...] = k_quant + self.new_vs[layer_id, ...] = v_quant + self.new_k_scaler[layer_id, ...] = kscale + self.new_v_scaler[layer_id, ...] = vscale + else: + if self.env.generate_cache_stacked: + self.new_ks.append(k_quant) + self.new_vs.append(v_quant) + self.new_k_scaler.append(kscale) + self.new_v_scaler.append(vscale) + else: + self.new_ks = k_quant + self.new_vs = v_quant + self.new_k_scaler = kscale + self.new_v_scaler = vscale + elif self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant self.k_scaler[:, :, self.input_pos, :] = kscale self.v_scaler[:, :, self.input_pos, :] = vscale else: - batch = jnp.arange(self.env.batch_size) - self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2) - self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2) - self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2) - self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2) - return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler + # We don't handle left aligned but lazy_cache_update=False + self.cache_k[self.batch, :, self.input_pos, :] = k_quant.squeeze(2) + self.cache_v[self.batch, :, self.input_pos, :] = v_quant.squeeze(2) + self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) + self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) + + return ( + self.cache_k, + self.cache_v, + k_quant, + v_quant, + self.k_scaler, + self.v_scaler, + kscale, + vscale, + ) + + def finalize(self): + """Finalize the cache operation and updates the cache.""" + if not self.env.lazy_cache_update: + return + if self.env.ring_buffer: + # Assume no cache stack for ring buffer + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) + ) + else: + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + # new kv scaler also has to go through shard_map instead of indexing + # because it needs to reshape to (batch, layer) which mess up with the data + caches = [ + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + else: + caches = [ + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + else: + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + self.input_pos, + ) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 
78f8da9f..70b530fc 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -90,6 +90,31 @@ "Whether to enable ring buffer", required=False, ) +flags.DEFINE_bool( + "flash_attention", + False, + "Whether to enable flas attention. Only takes effect at test mode", + required=False, +) +flags.DEFINE_bool( + "generate_cache_stacked", + False, + "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", + required=False, +) +flags.DEFINE_bool( + "new_cache_stacked", + False, + "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", + required=False, +) +flags.DEFINE_bool( + "lazy_cache_update", + False, + "Whether to update the cache during attention or delayed until all the layers are done. " + "Only takes effect at test mode", + required=False, +) flags.DEFINE_float( "temperature", 1.0, @@ -184,6 +209,10 @@ def create_engine_from_config_flags(): nucleus_topp=FLAGS.nucleus_topp, topk=FLAGS.topk, ring_buffer=FLAGS.ring_buffer, + flash_attention=FLAGS.flash_attention, + generate_cache_stacked=FLAGS.generate_cache_stacked, + new_cache_stacked=FLAGS.new_cache_stacked, + lazy_cache_update=FLAGS.lazy_cache_update, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 0b27db3d..79dfb945 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -145,7 +145,7 @@ def init_decode_state( (self.env.batch_size, self.env.cache_sequence_length), float("-inf"), dtype=self.default_dtype, - ), + ), # mask ) # pylint: disable-next=all @@ -195,7 +195,10 @@ def _call_model_generate( # The mode is needed so that tensors created inside of # the model (such as via torch.ones etc) also have the right type res = torch.func.functional_call(self.pt_model, paramst, argst) - updated_caches = [c.state() for c in caches_obj] + updated_caches = [] + for c in caches_obj: + c.finalize() + updated_caches.append(c.state()) scales = [] if self.env.quant_config.enable_kv_quantization: scales = [c.scalers() for c in caches_obj] @@ -218,7 +221,8 @@ def _call_model_prefill(self, weights, tokens, input_indexes): dtype=self.default_dtype, ) mask = jnp.triu(mask, k=1) - args = (tokens, input_indexes, caches, mask) + start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32) + args = (tokens, input_indexes, caches, mask, start) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: @@ -328,7 +332,7 @@ def _insert_no_wrap( tokens = decode_state.tokens.at[slot].set(prefix.token) x = jnp.arange(0, self.env.cache_sequence_length) - cond = jnp.logical_and(x <= current_pos, x >= pos) + cond = jnp.logical_and(x < current_pos, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) start = decode_state.start.at[slot].set( @@ -338,31 +342,48 @@ def _insert_no_wrap( if not self.env.quant_config.enable_kv_quantization: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, new_entry): + def insert(cache, new_entry, update_index): res = jax.lax.dynamic_update_slice( cache, new_entry, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res - caches = [ - (insert(k, newk), insert(v, newv)) - for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) - ] + if self.env.generate_cache_stacked: + caches = decode_state.caches + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + newk = 
jnp.expand_dims(newk, 0) + newv = jnp.expand_dims(newv, 0) + caches = [ + ( + insert(caches[0][0], newk, update_index), + insert(caches[0][1], newv, update_index), + ) + ] + else: + update_index = [slot, 0, pos, 0] + caches = [ + (insert(k, newk, update_index), insert(v, newv, update_index)) + for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) + ] else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + def insert(cache, scaler, new_entry, update_index): + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) + if self.env.generate_cache_stacked: + vals = jnp.expand_dims(vals, 0) + scales = jnp.expand_dims(scales, 0) new_scaler = jax.lax.dynamic_update_slice( scaler, scales, - [slot, 0, pos, 0], + update_index, ) new_scaler = jax.lax.with_sharding_constraint( new_scaler, self.replicated @@ -370,19 +391,37 @@ def insert(cache, scaler, new_entry): res = jax.lax.dynamic_update_slice( cache, vals, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res, new_scaler - for (k, v), (kscaler, vscaler), (newk, newv) in zip( - decode_state.caches, decode_state.cache_scales, prefix.caches - ): - kcache, kscale = insert(k, kscaler, newk) - vcache, vscale = insert(v, vscaler, newv) - caches.append((kcache, vcache)) - scales.append((kscale, vscale)) - + if self.env.generate_cache_stacked: + cache_k, k_scale = ( + decode_state.caches[0][0], + decode_state.cache_scales[0][0], + ) + cache_v, v_scale = ( + decode_state.caches[0][1], + decode_state.cache_scales[0][1], + ) + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + # newk = jnp.expand_dims(newk, 0) + # newv = jnp.expand_dims(newv, 0) + cache_k, k_scale = insert(cache_k, k_scale, newk, update_index) + cache_v, v_scale = insert(cache_v, v_scale, newv, update_index) + caches = [(cache_k, cache_v)] + scales = [(k_scale, v_scale)] + else: + update_index = [slot, 0, pos, 0] + for (k, v), (kscaler, vscaler), (newk, newv) in zip( + decode_state.caches, decode_state.cache_scales, prefix.caches + ): + kcache, kscale = insert(k, kscaler, newk, update_index) + vcache, vscale = insert(v, vscaler, newv, update_index) + caches.append((kcache, vcache)) + scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) return DecodeState( tokens, @@ -416,10 +455,10 @@ def _insert_wrap( cond = jax.lax.cond( decode_state.current_position > start_insert, lambda x, start_insert, current_position: jnp.logical_and( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), lambda x, start_insert, current_position: jnp.logical_or( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), x, start_insert, @@ -494,11 +533,6 @@ def insert( decode_state: DecodeState, slot: int, ) -> DecodeState: - # logging.info( - # 'Jet input prefix: %s, decode state before insert: %s', - # prefix, - # decode_state, - # ) if self.env.ring_buffer: start_insert = decode_state.current_position - prefix.seq_len end_insert = start_insert + prefix.caches[0][0].shape[2] # padded seclen @@ -580,11 +614,9 @@ def generate( pos = decode_state.current_position if self.env.ring_buffer: input_indexes = jnp.full((1,), pos) - mask = decode_state.mask.at[:, decode_state.current_position].set(0) else: input_indexes = decode_state.input_pos - batch = jnp.arange(self.env.batch_size) - mask = 
decode_state.mask.at[batch, decode_state.input_pos].set(0) + ragged_batch_index, ragged_block_index = ( self.precompute_ragged_block_indices(decode_state) ) @@ -592,6 +624,16 @@ def generate( (-1) ), ragged_block_index.reshape((-1)) + def update_mask(): + if self.env.ring_buffer: + return decode_state.mask.at[:, decode_state.current_position].set(0) + + batch = jnp.arange(self.env.batch_size) + return decode_state.mask.at[batch, decode_state.input_pos].set(0) + + mask = decode_state.mask + if not self.env.lazy_cache_update: + mask = update_mask() logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, @@ -605,6 +647,10 @@ def generate( ragged_block_index, ) + if self.env.lazy_cache_update: + # fill mask later, now use flash attention + mask = update_mask() + next_token = self._sampling(logits, self.env.batch_size) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -648,11 +694,6 @@ def generate( input_pos, mask, ) - print( - "new_pos", - (decode_state.current_position + 1) % self.env.cache_sequence_length, - ) - print(f"new_token: {jnp.squeeze(next_token)}") return new_decode_state, result_tokens # pylint: disable-next=all @@ -832,6 +873,10 @@ def create_pytorch_engine( nucleus_topp=None, topk=None, ring_buffer=True, + flash_attention=False, + generate_cache_stacked=False, + new_cache_stacked=False, + lazy_cache_update=False, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -902,6 +947,10 @@ def create_pytorch_engine( nucleus_topp=nucleus_topp, topk=topk, ring_buffer=ring_buffer, + flash_attention=flash_attention, + generate_cache_stacked=generate_cache_stacked, + new_cache_stacked=new_cache_stacked, + lazy_cache_update=lazy_cache_update, ) if shard_on_batch and sharding_config: diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index fb1b99ba..84289d90 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -16,6 +16,7 @@ from typing import Tuple import jax +import jax.numpy as jnp import jax.sharding as jsharding from jax.experimental import mesh_utils import torch_xla2 @@ -91,11 +92,18 @@ class JetEngineEnvironmentData: block_size: int = 512 # Starting position - starting_position: int = 512 + starting_position: int = 0 # Ring buffer ring_buffer: bool = True + flash_attention: bool = False + + generate_cache_stacked: bool = False + + new_cache_stacked: bool = False + + lazy_cache_update: bool = False # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -109,6 +117,10 @@ class JetEngineEnvironmentData: # temperature parameter for scaling probability temperature: float = 1.0 + testing: bool = False + + testing_seed: int = 0 + # pylint: disable-next=all class JetEngineEnvironment: @@ -119,10 +131,34 @@ def __init__(self, data: JetEngineEnvironmentData): self.batch_size = self._data.batch_size self.seq_len = self._data.max_input_sequence_length self.cache_len = self._data.cache_sequence_length - self.ragged_mha = self._data.ragged_mha self.block_size = self._data.block_size self.starting_position = self._data.starting_position + self.num_layers = self._data.num_layers + self.testing = self._data.testing + self.testing_seed = self._data.testing_seed self.ring_buffer = self._data.ring_buffer + + if not self.ring_buffer: + self.lazy_cache_update = True + self.ragged_mha = True + self.flash_attention = True + self.generate_cache_stacked = True + self.new_cache_stacked = True + + if self.testing: + 
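+      # In test mode each feature flag comes straight from the config instead of
+      # the defaults forced on by the non-ring-buffer branch above.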
self.lazy_cache_update = self._data.lazy_cache_update + self.ragged_mha = self._data.ragged_mha + self.flash_attention = self._data.flash_attention + self.generate_cache_stacked = self._data.generate_cache_stacked + self.new_cache_stacked = self._data.new_cache_stacked + + self.default_type = jnp.bfloat16 if self._data.bf16_enable else jnp.float32 + + if self.generate_cache_stacked: + self.cache_shape = (self.num_layers, *self._data.cache_shape) + else: + self.cache_shape = self._data.cache_shape + P = jax.sharding.PartitionSpec num_of_partitions = jax.device_count() @@ -136,19 +172,29 @@ def __init__(self, data: JetEngineEnvironmentData): self.x_sharding = jsharding.NamedSharding(self.mesh, P("x")) self.replicated = jsharding.NamedSharding(self.mesh, P()) + if self.generate_cache_stacked: + self.attention_kv_axis_names = ( + "layer", + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", + ) if data.shard_on_batch: - cache_sharding_axis = 0 + self.kv_cache_shard_axis = "batch" else: - cache_sharding_axis = self.attention_kv_axis_names.index( - self.kv_cache_shard_axis - ) + self.kv_cache_shard_axis = "num_attn_heads" - if self.cache_shape[cache_sharding_axis] == 1: + self.cache_sharding_axis = self.attention_kv_axis_names.index( + self.kv_cache_shard_axis + ) + + if self.cache_shape[self.cache_sharding_axis] == 1: # cannot shard on an axis that is 1 # default to last - cache_sharding_axis = len(self.cache_shape) - 1 + self.cache_sharding_axis = len(self.cache_shape) - 1 - self.cache_sharding = self.sharding_by_axis(cache_sharding_axis) + self.cache_sharding = self.sharding_by_axis(self.cache_sharding_axis) self._load_sharding_config() def _load_sharding_config(self): @@ -194,19 +240,20 @@ def make_caches_prefill(self): def make_caches_generate(self): """Create kv caches for inference generation""" caches = [] - shape = self._data.cache_shape - for _ in range(self.num_layers): + layered_cache_count = 1 if self.generate_cache_stacked else self.num_layers + + for _ in range(layered_cache_count): if self._data.quant_config.enable_kv_quantization: caches.append( cache_manager.Int8KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + self.cache_shape, self.cache_sharding, env=self ) ) else: caches.append( cache_manager.KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + self.cache_shape, self.cache_sharding, env=self ) ) return caches diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 3ecd875a..ed6ffff0 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -367,9 +367,19 @@ def apply_rotary_emb( def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" - bs, n_kv_heads, slen, head_dim = x.shape + *_, bs, n_kv_heads, slen, head_dim = x.shape + stacked = x.ndim == 5 + if n_rep == 1: return x + + if stacked: + layer = x.shape[0] + return ( + x[:, :, :, None, :, :] + .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim) + ) return ( x[:, :, None, :, :] .expand(bs, n_kv_heads, n_rep, slen, head_dim) @@ -379,18 +389,36 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class AttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = 
( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.flash_attention = ak.flash_attention + self.ragged_attention_orig = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -413,53 +441,141 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - with jax.named_scope("attn_insert_cache"): - keys, values = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + def attend(xq, keys, values, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if n_rep == 1 and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) - with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( - self.ragged_attention, + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( + impl, xq, keys, values, + self.layer_id, start, end, ragged_batch_index, ragged_block_index, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention( + xq, keys, values, self.layer_id, mask=local_mask + ) else: - output = self.dense_attention(xq, keys, values, None, None, mask) + local_output = self.dense_attention( + xq, keys, values, None, None, local_mask + ) + local_max = None + local_denom = None + + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) + + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + if local_denom is not None: + local_denom = local_denom[:, :, 0:seqlen, :] + + # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + # if local_max is not None and local_denom is not None: + # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + self.env.apply_sharding(local_output, axis=self.q_shard_axis) + return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + orig_keys, 
orig_values = cache.update(xk, xv, self.layer_id) + # We are not using ragged attention for prefill yet. + if not self.env.ragged_mha or seqlen > 1: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) + + # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, mask + ) + # Updating cache during each step still has very large impact on latency. + # For non flash attention or prefill, existing output contains everything + if not self.env.lazy_cache_update or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + if not self.env.ragged_mha or seqlen > 1: + xk = repeat_kv(xk, n_rep) + xv = repeat_kv(xv, n_rep) + new_output, (new_max, new_denom) = attend(xq, xk, xv, None) + + with jax.named_scope("attn_global"): + # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") + + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output - if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] - # For XLA matmul performance boost - # output = torch.matmul(scores, values) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + return attn_out class Int8KVAttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = ( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.flash_attention = ak.flash_attention + self.ragged_attention_orig = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -482,24 +598,33 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = 
torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - - with jax.named_scope("attn_insert_cache"): - keys, values, k_scaler, v_scaler = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if n_rep == 1 and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) + # xq = torch.broadcast_to(xq, (bsz, num_heads, true_len, head_dim)) - with jax.named_scope("attn_qkv"): + # We are not using ragged attention for prefill yet. if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( - self.ragged_attention, + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( + impl, xq, keys, values, + self.layer_id, start, end, ragged_batch_index, @@ -507,22 +632,94 @@ def __call__( k_scaler, v_scaler, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention( + xq, + keys, + values, + self.layer_id, + k_scaler, + v_scaler, + mask=local_mask, + ) else: - output = self.dense_attention( - xq, keys, values, k_scaler, v_scaler, mask + local_output = self.dense_attention( + xq, keys, values, k_scaler, v_scaler, local_mask ) + local_max = None + local_denom = None - if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + local_denom = local_denom[:, :, 0:seqlen, :] + + self.env.apply_sharding(local_output, axis=self.q_shard_axis) + return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + ( + orig_keys, + orig_values, + new_key, + new_value, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + ) = cache.update(xk, xv, self.layer_id) + # We are not using ragged attention for prefill yet. 
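As with the float kernel above, when lazy_cache_update is enabled the int8 path below runs attention twice, once over the existing (stale) cache and once over the newly inserted key/value, and then merges the two partial results in the attn_global scope. A minimal sketch of that merge, using the same formula the patch applies (the function and argument names here are illustrative, not part of the change):

```
import torch

def merge_partial_attention(out_a, max_a, denom_a, out_b, max_b, denom_b):
    # Each partial attention returns its normalized output together with the
    # softmax statistics for its key range: the running row max and the
    # softmax denominator. Re-weighting each output by denom * exp(max) and
    # renormalizing by the global sum reproduces attention over the union of
    # the two key ranges.
    weight_a = denom_a * torch.exp(max_a)
    weight_b = denom_b * torch.exp(max_b)
    global_sum = weight_a + weight_b
    return (out_a * weight_a + out_b * weight_b) / global_sum
```

Dense attention does not return the max/denominator statistics, which is why the lazy-update merge is only taken on the ragged or flash attention paths during decode (seqlen == 1).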
+ if not self.env.ragged_mha or seqlen > 1: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, k_scaler, v_scaler, mask + ) + + # For non flash attention or prefill, existing output contains everything + if not self.env.lazy_cache_update or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + # At this point, flash attention or ragged attention must have been enabled + if not self.env.ragged_mha or seqlen > 1: + new_key = repeat_kv(new_key, n_rep) + new_value = repeat_kv(new_value, n_rep) + new_output, (new_max, new_denom) = attend( + xq, new_key, new_value, new_k_scaler, new_v_scaler, None + ) + + with jax.named_scope("attn_global"): + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output + + return attn_out class Attention(ModuleBase): """Attention module.""" - def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): + def __init__( + self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads @@ -530,6 +727,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): self.n_rep = self.n_heads // self.n_kv_heads self.env = env self.hidden_size = hidden_size + self.layer_id = layer_id LinearLayer = get_quantized_linear_layer(env.quant_config) linear_kwargs = {} @@ -549,7 +747,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): if env.quant_config.enable_kv_quantization else AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, self.layer_id) self.q_size = n_heads * self.head_dim self.kv_size = self.n_kv_heads * self.head_dim @@ -629,16 +827,26 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) + if mask.ndim == 2: + if seqlen == 1: + mask = mask[:, None, None, :] + else: + mask = mask[None, None, :, :] + + # if cache is not None and cache.cache_k is not None: + # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, xk, xv, mask, + # cache[self.layer_id], cache, start, end, ragged_batch_index, ragged_block_index, ).type_as(xq) + # print(f"output {output.shape}") output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 2b88055c..51f204ed 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -596,7 +596,7 @@ def insert(cache, new_entry): @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 1072dad9..5773b8bd 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -73,6 +73,7 @@ def __init__( head_dim: int, device, env, + layer_id, ): 
super().__init__() @@ -135,7 +136,7 @@ def __init__( if env.quant_config.enable_kv_quantization else layers.AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, layer_id) def forward( self, @@ -272,7 +273,7 @@ def forward(self, x): class GemmaDecoderLayer(nn.Module): - def __init__(self, config: gemma_config.GemmaConfig, env): + def __init__(self, config: gemma_config.GemmaConfig, env, layer_id): super().__init__() self.self_attn = GemmaAttention( config.hidden_size, @@ -281,6 +282,7 @@ def __init__(self, config: gemma_config.GemmaConfig, env): config.head_dim, config.device, env, + layer_id, ) self.mlp = GemmaMLP( @@ -340,8 +342,8 @@ def __init__(self, config: gemma_config.GemmaConfig, env): self.env = env self.layers = nn.ModuleList() - for _ in range(config.num_hidden_layers): - self.layers.append(GemmaDecoderLayer(config, env)) + for layer_id in range(config.num_hidden_layers): + self.layers.append(GemmaDecoderLayer(config, env, layer_id)) self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, device=config.device ) diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index 7956667d..1b72c0a7 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -45,7 +45,8 @@ def get_arg( "dim": 128, "vocab_size": 32000, "multiple_of": 32, - "n_heads": 8, + "n_heads": 64, + "n_kv_heads": 8, "n_layers": 3, "norm_eps": 1e-05, } diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 7e700180..919eff35 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -103,6 +103,7 @@ def __init__( args.dim, env=env, device=args.device, + layer_id=layer_id, ) self.feed_forward = FeedForward( dim=args.dim, @@ -249,7 +250,6 @@ def forward( ragged_batch_index: precomputed batch index for ragged attention ragged_block_index: precomputed block index for ragged attention """ - with jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] h = self.tok_embeddings(tokens) @@ -259,12 +259,16 @@ def forward( freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - assert len(caches) == len( - self.layers - ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" end = None if start is None else (start + input_pos) % self.env.cache_len - for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)): + # For stacked case, cannot get cache inside the loop which will cause cache copy + for layer_id, layer in enumerate(self.layers): + if caches[0].stacked: + cache = caches[0] + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): h = layer( h, freqs_cis, diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index b0d8d573..7396513d 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -38,7 +38,8 @@ def __init__(self, config: ModelArgs, env) -> None: config.vocab_size, config.dim, device=config.device ) self.layers = nn.ModuleList( - TransformerBlock(config, env) for _ in range(config.n_layer) + TransformerBlock(config, env, layer_id) + for layer_id in range(config.n_layer) ) self.norm = RMSNorm(config.dim, 
eps=config.norm_eps) LinearLayer = get_quantized_linear_layer(env.quant_config) @@ -142,7 +143,7 @@ def get_weight_sharding_type(): class TransformerBlock(nn.Module): - def __init__(self, config: ModelArgs, env) -> None: + def __init__(self, config: ModelArgs, env, layer_id) -> None: super().__init__() self.attention = Attention( config.n_head, @@ -151,6 +152,7 @@ def __init__(self, config: ModelArgs, env) -> None: config.dim, env=env, device=config.device, + layer_id=layer_id, ) self.block_sparse_moe = MOEFeedForward(config, config.device, env) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) diff --git a/mlperf/README.md b/mlperf/README.md new file mode 100644 index 00000000..b3322c1c --- /dev/null +++ b/mlperf/README.md @@ -0,0 +1,31 @@ +# Run MLPerf tests + +NOTE: currently only tried with mixtral; +and only tried with offline benchmark + +# How to run + +### 1. Install + +``` +./install.sh +``` + +### 2. Start server + +``` +./start_server.sh +``` + +### 3. Warm up the server + +``` +python warmup.py +``` + +### 4. Run the benchmark, now it runs offline mode + +``` +./benchmark_run.sh +``` + diff --git a/mlperf/backend.py b/mlperf/backend.py new file mode 100644 index 00000000..806eb727 --- /dev/null +++ b/mlperf/backend.py @@ -0,0 +1,352 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""mlperf loadgen interface for LLama2.""" +import math +import array +import concurrent.futures +import dataclasses +import json +import logging +from operator import itemgetter # pylint: disable=g-importing-member +import time +from typing import List, Optional, Any + +import numpy as np + +from . 
import dataset + +import mlperf_loadgen as lg + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc + +from transformers import AutoTokenizer + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("backend.py") + + +@dataclasses.dataclass +class WarmupSample: + id: int + index: int + + +@dataclasses.dataclass +class StreamResponse: + result: str = None + + +def _find_interesting_samples(max_length, dataset, encoder): + start_len = 1 + lengths = [len(encoder.encode(data)) for data in dataset] + min_length = min(lengths) + max_length = min(max(lengths), max_length) + start_len = 2 ** int(math.log2(min_length)) + while start_len * 2 < max_length: + for i, data in enumerate(dataset): + length = len(encoder.encode(data)) + if start_len < length <= start_len * 2: + log.info(f"Warmup sample: id={i} of length={length}") + yield i + break # for + else: + log.info( + f"DISCARD Warmup sample: id={i} of length={length} for {start_len}" + ) + start_len *= 2 + + +class ThreadedLMClient: + """Holds a thread pool and a loadgen client for LM inference.""" + + _thread_pool: concurrent.futures.ThreadPoolExecutor + _dataset: dataset.Dataset + _futures = List[concurrent.futures.Future] + + def __init__( + self, + is_stream: bool, + num_threads: int, + api_url: str, + dataset_object: dataset.Dataset, + input_mode: str, + output_mode: str, + tokenizer: Optional[AutoTokenizer] = None, + max_output_len: int = 1024, + log_interval: int = 1000, + ): + log.info(f"Initiating {self.__class__.__name__} ...") + self._is_stream = is_stream + self._input_mode = dataset.validate_sample_mode(input_mode) + self._output_mode = dataset.validate_sample_mode(output_mode) + if self._input_mode == "text" or self._output_mode == "text": + assert tokenizer is not None + self._tokenizer = tokenizer + self._max_output_len = max_output_len + + self._log_interval = log_interval + + self._thread_pool = concurrent.futures.ThreadPoolExecutor(num_threads) + self._api_url = api_url + self._dataset = dataset_object + self._futures = [] + self.pred_outputs = {} + self._resp_cnt = 0 + + # Post processing stop sequence for Mixtral MXBP dataset + self._stop_seq: List[int] = [13, 13940, 28832, 13] + self._stop_seq_len = len(self._stop_seq) + + log.info("Creating grpc channel with api_url {}".format(api_url)) + options = [("grpc.keepalive_timeout_ms", 10000)] + self._grpc_channel = grpc.insecure_channel(api_url, options=options) + + @property + def tokenizer(self): + return self._tokenizer + + def _log_resp_cnt(self): + self._resp_cnt += 1 + if self._resp_cnt % self._log_interval == 0: + log.info("Completed %d queries", self._resp_cnt) + + def post_process_response(self, response_tokens): + for i in range(self._stop_seq_len, len(response_tokens)): + if response_tokens[i - self._stop_seq_len : i] == self._stop_seq: + # log.info(f"Post process found stop seq: {response_tokens}") + return response_tokens[:i] + + # log.info(f"Post process no-op for {response_tokens}") + return response_tokens + + def process_single_sample_async(self, query_sample, warmup): + """Executes a single query and marks responses complete asynchronously. + + Args: + query_sample: Single prompt + warmup: Indicates that this is a warmup request. 
+ """ + future = self._thread_pool.submit( + self._process_sample, query_sample, warmup + ) + self._futures.append(future) + + def flush(self): + concurrent.futures.wait(self._futures) + self._futures = [] + + def _grpc_request(self, request, sample, warmup): + """Send grpc synchronous request since the current grpc server is sync.""" + stub = jetstream_pb2_grpc.OrchestratorStub(self._grpc_channel) + token_list = [] + ttft = 0 + start_time = time.perf_counter() + response = stub.Decode(request) + for resp in response: + if not warmup and self._is_stream and ttft == 0: + # TTFT for online mode + ttft = time.perf_counter() - start_time + log.info("TTFT {}ms".format(ttft * 1000)) + response_token_ids = resp.stream_content.samples[0].token_ids + assert len(response_token_ids) == 1 + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + first_token_response = lg.QuerySampleResponse( + sample.id, response_info[0], response_info[1] + ) + lg.FirstTokenComplete([first_token_response]) + log.info("mark first token complete") + token_list.extend(resp.stream_content.samples[0].token_ids) + return token_list + + def _process_sample(self, sample, warmup): + """Processes a single sample.""" + sample_data = self._dataset.inputs[sample.index] + if self._input_mode == "text": + token_ids = self._tokenizer.encode(sample_data) + else: + assert self._input_mode == "tokenized" + token_ids = [int(token_id_str) for token_id_str in sample_data.split(",")] + + request = jetstream_pb2.DecodeRequest( + session_cache="", + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + priority=0, + max_tokens=self._max_output_len, + ) + generated_token_list = self._grpc_request(request, sample, warmup) + if not warmup: + try: + dataset_name = self._dataset.input_datasets[sample.index] + if dataset_name == "MBXP": + response_token_ids = self.post_process_response(generated_token_list) + else: + response_token_ids = generated_token_list + except Exception as e: + log.info(f"Error - {e}") + response_token_ids = generated_token_list + n_tokens = len(response_token_ids) + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + response_data = response_info[0] + response_size = response_info[1] * response_array.itemsize + query_sample_response = lg.QuerySampleResponse( + sample.id, response_data, response_size, n_tokens + ) + lg.QuerySamplesComplete([query_sample_response]) + # log.info(f"mark query as complete for - {dataset_name}") + log.info(f"mark query as complete") + pred_output = self._tokenizer.decode(response_token_ids) + self.pred_outputs[sample.index] = pred_output + self._log_resp_cnt() + + +class SUT: + """SUT.""" + + def __init__( + self, + scenario, + api_url, + is_stream, + input_mode, + output_mode, + max_output_len, + dataset_path, + total_sample_count, + tokenizer_path=None, + perf_count_override=None, + num_client_threads=200, + log_interval=1000, + batch_size_exp=5, + pred_outputs_log_path=None, + ): + log.info(f"Starting {scenario} SUT with {api_url}.") + self._is_stream = is_stream + self._input_mode = dataset.validate_sample_mode(input_mode) + self._output_mode = dataset.validate_sample_mode(output_mode) + assert tokenizer_path is not None + self._tokenizer = self.load_tokenizer(tokenizer_path) + self._max_output_len = 
max_output_len + self._api_url = api_url + self._dataset_path = dataset_path + self._total_sample_count = total_sample_count + self._perf_count_override = perf_count_override + self._num_client_threads = num_client_threads + self._log_interval = log_interval + self._batch_size_exp = batch_size_exp + self._pred_outputs_log_path = pred_outputs_log_path + + log.info("Loading Dataset ... ") + self.dataset = dataset.Dataset( + dataset_path=self._dataset_path, + input_mode=self._input_mode, + total_sample_count=self._total_sample_count, + perf_count_override=self._perf_count_override, + ) + + client_cls = ThreadedLMClient + self._client = client_cls( + is_stream=self._is_stream, + num_threads=self._num_client_threads, + api_url=self._api_url, + dataset_object=self.dataset, + input_mode=self._input_mode, + output_mode=self._output_mode, + tokenizer=self._tokenizer, + max_output_len=self._max_output_len, + log_interval=self._log_interval, + ) + + self.qsl = lg.ConstructQSL( + self.dataset.total_sample_count, + self.dataset.perf_count, + self.dataset.LoadSamplesToRam, + self.dataset.UnloadSamplesFromRam, + ) + self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) + + def load_tokenizer( + self, tokenizer_path: Optional[str] = None + ) -> Optional[AutoTokenizer]: + """Returns tokenizer""" + if tokenizer_path is not None: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + model_max_length=1024, + padding_side="left", + use_fast=True, + ) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def _sort_issue_queries(self, query_samples): + """Issue queries.""" + query_samples_with_length = [] + for query_sample in query_samples: + query_sample_token_length = self.dataset.inputs_with_token_lengths[ + query_sample.index + ][1] + query_samples_with_length.append( + (query_sample_token_length, query_sample) + ) + sorted_query_samples_with_length = sorted( + query_samples_with_length, key=itemgetter(0) + ) + sorted_query_samples = [x[1] for x in sorted_query_samples_with_length] + return sorted_query_samples + + def issue_queries(self, query_samples): + """Issue queries.""" + num_query_samples = len(query_samples) + if num_query_samples > 1: + log.info(f"Issuing {num_query_samples} queries. ") + query_samples = self._sort_issue_queries(query_samples) + for query_sample in query_samples: + self._client.process_single_sample_async(query_sample, False) + + def flush_queries(self): + """Flush queries.""" + log.info("Loadgen has completed issuing queries... ") + self._client.flush() + + if self._pred_outputs_log_path is not None: + + pred_outputs = [] + for idx, x in self._client.pred_outputs.items(): + pred_output = { + "qsl_idx": idx, + "intput": self._client._dataset.inputs[idx], + "data": x, + } + pred_outputs.append(pred_output) + log.info(f"Generated {len(pred_outputs)} prediction outputs") + + if pred_outputs: + self.accuracy_log = open(self._pred_outputs_log_path, "w") + self.accuracy_log.write(json.dumps(pred_outputs)) + self.accuracy_log.flush() + self.accuracy_log.close() + log.info("Dumpped prediction outputs to accuracy log... 
") + + def __del__(self): + print("Finished destroying SUT.") diff --git a/mlperf/benchmark_run.sh b/mlperf/benchmark_run.sh new file mode 100755 index 00000000..946c301a --- /dev/null +++ b/mlperf/benchmark_run.sh @@ -0,0 +1,32 @@ +BASEDIR=mlperf +API_URL=0.0.0.0:9000 +USER_CONFIG=$BASEDIR/user.conf +DATA_DISK_DIR=$BASEDIR/data +TOTAL_SAMPLE_COUNT=1000 +DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl + +# HF model id +TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" + +LOADGEN_RUN_TYPE=offline-performance +OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +pushd .. +python -m mlperf.main \ + --api-url ${API_URL} \ + --scenario Offline \ + --input-mode tokenized \ + --output-mode tokenized \ + --log-pred-outputs \ + --mlperf-conf $BASEDIR/mlperf.conf \ + --user-conf ${USER_CONFIG} \ + --audit-conf no-audit \ + --total-sample-count ${TOTAL_SAMPLE_COUNT} \ + --dataset-path ${DATASET_PATH} \ + --tokenizer-path ${TOKENIZER_PATH} \ + --log-interval 1000 \ + --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log +popd \ No newline at end of file diff --git a/mlperf/dataset.py b/mlperf/dataset.py new file mode 100644 index 00000000..373bbc49 --- /dev/null +++ b/mlperf/dataset.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import pandas as pd + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("dataset.py") + + +class Dataset: + + def __init__( + self, + dataset_path: str, + input_mode: str, + total_sample_count: int = 15000, + perf_count_override: int = None, + ): + if not os.path.isfile(dataset_path): + log.warn( + "Processed pickle file {} not found. 
Please check that the path is correct".format( + dataset_path + ) + ) + self.dataset_path = dataset_path + + self._input_mode = validate_sample_mode(input_mode) + self.load_processed_dataset() + + self.total_sample_count = min(len(self.input_ids_strs), total_sample_count) + self.perf_count = perf_count_override or self.total_sample_count + + @property + def input_ids_strs(self): + return self._input_ids_strs + + @property + def input_texts(self): + return self._input_texts + + @property + def input_token_lengths(self): + return self._input_token_lengths + + @property + def inputs(self): + return self._inputs + + @property + def inputs_with_token_lengths(self): + return self._inputs_with_token_lengths + + @property + def input_datasets(self): + return self._input_datasets + + def load_processed_dataset(self): + processed_data = pd.read_pickle(self.dataset_path) + # processed_data = processed_data[processed_data["dataset"] == "MBXP"] + # processed_data = processed_data.reset_index(drop=True) + + self._input_ids_strs = [] + for input_ids in processed_data["tok_input"]: + input_ids_str = ",".join([str(input_id) for input_id in input_ids]) + self._input_ids_strs.append(input_ids_str) + + self._input_texts = [] + for input_text in processed_data["input"]: + self._input_texts.append(input_text) + + self._input_token_lengths = [] + for token_length in processed_data["tok_input_len"]: + self._input_token_lengths.append(token_length) + + log.info(f"input_mode is {self._input_mode}") + self._inputs = ( + self._input_ids_strs + if self._input_mode == "tokenized" + else self._input_texts + ) + log.info(f"example sample input is {self._inputs[0]}") + self._inputs_with_token_lengths = [ + (input_ids_str_or_input_text, token_length) + for input_ids_str_or_input_text, token_length in zip( + self._inputs, self._input_token_lengths + ) + ] + + self._input_datasets = [] + for dataset in processed_data["dataset"]: + self._input_datasets.append(dataset) + log.info( + f"example sample input dataset is {self._input_datasets[0]} and total {len(self._input_datasets)}" + ) + + def LoadSamplesToRam(self, sample_list): + pass + + def UnloadSamplesFromRam(self, sample_list): + pass + + def __del__(self): + pass + + +SAMPLE_MODE_CHOICES = ["tokenized", "text"] + + +def validate_sample_mode(sample_mode: str) -> str: + if sample_mode not in SAMPLE_MODE_CHOICES: + raise ValueError( + "The sample_mode should be set to either `tokenized` or `text`." 
+ ) + return sample_mode diff --git a/mlperf/install.sh b/mlperf/install.sh new file mode 100644 index 00000000..3a8f037b --- /dev/null +++ b/mlperf/install.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +DATA_DISK_DIR=data + +mkdir -p $DATA_DISK_DIR + +pip install -U "huggingface_hub[cli]" +pip install \ + transformers \ + nltk==3.8.1 \ + evaluate==0.4.0 \ + absl-py==1.4.0 \ + rouge-score==0.1.2 \ + sentencepiece==0.1.99 \ + accelerate==0.21.0 + +# install loadgen +pip install mlperf-loadgen + + +pushd $DATA_DISK_DIR + +# model weights +gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive +# NOTE: uncomment one so you dont download too much weights to your box +# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive + +# Get mixtral data +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl +mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl +mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl + +# Get llama70b data +gcloud storage cp \ + gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \ + processed-calibration-data.pkl +gcloud storage cp \ + gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \ + processed-data.pkl +popd diff --git a/mlperf/main.py b/mlperf/main.py new file mode 100644 index 00000000..ad0fe7e2 --- /dev/null +++ b/mlperf/main.py @@ -0,0 +1,212 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import gc +import logging +import os +import sys + +from . import backend + +import mlperf_loadgen as lg + +_MLPERF_ID = "mixtral-8x7b" + +sys.path.insert(0, os.getcwd()) + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("main.py") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--scenario", + type=str, + choices=["Offline", "Server"], + default="Offline", + help="Scenario", + ) + parser.add_argument( + "--api-url", type=str, default=None, help="SAX published model path." 
+ ) + parser.add_argument("--dataset-path", type=str, default=None, help="") + parser.add_argument("--tokenizer-path", type=str, default=None, help="") + parser.add_argument( + "--accuracy", action="store_true", help="Run accuracy mode" + ) + parser.add_argument("--is-stream", action="store_true", help="") + parser.add_argument( + "--input-mode", + type=str, + choices=["text", "tokenized"], + default="tokenized", + ) + parser.add_argument( + "--output-mode", + type=str, + choices=["text", "tokenized"], + default="tokenized", + ) + parser.add_argument( + "--max-output-len", type=int, default=1024, help="Maximum output len" + ) + parser.add_argument( + "--audit-conf", + type=str, + default="audit.conf", + help="audit config for LoadGen settings during compliance runs", + ) + parser.add_argument( + "--mlperf-conf", + type=str, + default="mlperf.conf", + help="mlperf rules config", + ) + parser.add_argument( + "--user-conf", + type=str, + default="user.conf", + help="user config for user LoadGen settings such as target QPS", + ) + parser.add_argument( + "--total-sample-count", + type=int, + default=15000, + help="Number of samples to use in benchmark.", + ) + parser.add_argument( + "--perf-count-override", + type=int, + default=None, + help="Overwrite number of samples to use in benchmark.", + ) + parser.add_argument( + "--output-log-dir", + type=str, + default="output-logs", + help="Where logs are saved.", + ) + parser.add_argument( + "--enable-log-trace", + action="store_true", + help="Enable log tracing. This file can become quite large", + ) + parser.add_argument( + "--num-client-threads", + type=int, + default=200, + help="Number of client threads to use", + ) + parser.add_argument("--batch-size-exp", type=int, default=6, help="") + parser.add_argument("--log-pred-outputs", action="store_true", help="") + parser.add_argument( + "--log-interval", + type=int, + default=1000, + help="Logging interval in seconds", + ) + parser.add_argument( + "--user-conf-override-path", + type=str, + default="", + help="When given overrides the default user.conf path", + ) + + args = parser.parse_args() + return args + + +scenario_map = { + "offline": lg.TestScenario.Offline, + "server": lg.TestScenario.Server, +} + + +def main(): + args = get_args() + + settings = lg.TestSettings() + settings.scenario = scenario_map[args.scenario.lower()] + if args.user_conf_override_path: + user_conf = args.user_conf_override_path + else: + user_conf = args.user_conf + + settings.FromConfig(args.mlperf_conf, _MLPERF_ID, args.scenario) + settings.FromConfig(user_conf, _MLPERF_ID, args.scenario) + log.info("Mlperf config: %s", args.mlperf_conf) + log.info("User config: %s", user_conf) + + if args.accuracy: + settings.mode = lg.TestMode.AccuracyOnly + log.warning( + "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet" + ) + else: + settings.mode = lg.TestMode.PerformanceOnly + settings.print_timestamps = True + + settings.use_token_latencies = True + + os.makedirs(args.output_log_dir, exist_ok=True) + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = args.output_log_dir + log_output_settings.copy_summary_to_stdout = True + log_settings = lg.LogSettings() + log_settings.log_output = log_output_settings + log_settings.enable_trace = args.enable_log_trace + + sut = backend.SUT( + scenario=args.scenario.lower(), + api_url=args.api_url, + is_stream=args.is_stream, + input_mode=args.input_mode, + output_mode=args.output_mode, + 
max_output_len=args.max_output_len, + dataset_path=args.dataset_path, + total_sample_count=args.total_sample_count, + tokenizer_path=args.tokenizer_path, + perf_count_override=args.perf_count_override, + num_client_threads=args.num_client_threads, + log_interval=args.log_interval, + batch_size_exp=args.batch_size_exp, + pred_outputs_log_path=os.path.join( + args.output_log_dir, "pred_outputs_logger.json" + ) + if args.log_pred_outputs + else None, + ) + + lgSUT = sut.sut # lg.ConstructSUT(sut.issue_queries, sut.flush_queries) + log.info("Starting Benchmark run") + lg.StartTestWithLogSettings( + lgSUT, sut.qsl, settings, log_settings, args.audit_conf + ) + + log.info("Run Completed!") + + log.info("Destroying SUT...") + lg.DestroySUT(lgSUT) + + log.info("Destroying QSL...") + lg.DestroyQSL(sut.qsl) + + +if __name__ == "__main__": + # Disable garbage collection to avoid stalls when running tests. + gc.disable() + main() diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf new file mode 100644 index 00000000..9400d0af --- /dev/null +++ b/mlperf/mlperf.conf @@ -0,0 +1,98 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +# User can optionally set this to higher values in user.conf. +resnet50.*.performance_sample_count_override = 1024 +ssd-mobilenet.*.performance_sample_count_override = 256 +retinanet.*.performance_sample_count_override = 64 +bert.*.performance_sample_count_override = 10833 +dlrm.*.performance_sample_count_override = 204800 +dlrm-v2.*.performance_sample_count_override = 204800 +rnnt.*.performance_sample_count_override = 2513 +gptj.*.performance_sample_count_override = 13368 +llama2-70b.*.performance_sample_count_override = 24576 +stable-diffusion-xl.*.performance_sample_count_override = 5000 +# set to 0 to let entire sample set to be performance sample +3d-unet.*.performance_sample_count_override = 0 + +# Set seeds. The seeds will be distributed two weeks before the submission. +*.*.qsl_rng_seed = 3066443479025735752 +*.*.sample_index_rng_seed = 10688027786191513374 +*.*.schedule_rng_seed = 14962580496156340209 +# Set seeds for TEST_05. The seeds will be distributed two weeks before the submission. 
+*.*.test05_qsl_rng_seed = 16799458546791641818 +*.*.test05_sample_index_rng_seed = 5453809927556429288 +*.*.test05_schedule_rng_seed = 5435552105434836064 + + +*.SingleStream.target_latency_percentile = 90 +*.SingleStream.min_duration = 600000 + +*.MultiStream.target_latency_percentile = 99 +*.MultiStream.samples_per_query = 8 +*.MultiStream.min_duration = 600000 +*.MultiStream.min_query_count = 662 +retinanet.MultiStream.target_latency = 528 + +# 3D-UNet uses equal issue mode because it has non-uniform inputs +3d-unet.*.sample_concatenate_permutation = 1 + +# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario +gptj.*.sample_concatenate_permutation = 1 +llama2-70b.*.sample_concatenate_permutation = 1 +mixtral-8x7B.*.sample_concatenate_permutation = 1 + +*.Server.target_latency = 10 +*.Server.target_latency_percentile = 99 +*.Server.target_duration = 0 +*.Server.min_duration = 600000 +resnet50.Server.target_latency = 15 +retinanet.Server.target_latency = 100 +bert.Server.target_latency = 130 +dlrm.Server.target_latency = 60 +dlrm-v2.Server.target_latency = 60 +rnnt.Server.target_latency = 1000 +gptj.Server.target_latency = 20000 +stable-diffusion-xl.Server.target_latency = 20000 +# Llama2-70b benchmarks measures token latencies +llama2-70b.*.use_token_latencies = 1 +mixtral-8x7b.*.use_token_latencies = 1 +# gptj benchmark infers token latencies +gptj.*.infer_token_latencies = 1 +gptj.*.token_latency_scaling_factor = 69 +# Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0 +llama2-70b.Server.target_latency = 0 +llama2-70b.Server.ttft_latency = 2000 +llama2-70b.Server.tpot_latency = 200 + +mixtral-8x7b.Server.target_latency = 0 +mixtral-8x7b.Server.ttft_latency = 2000 +mixtral-8x7b.Server.tpot_latency = 200 + +*.Offline.target_latency_percentile = 90 +*.Offline.min_duration = 600000 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit + +resnet50.Offline.min_query_count = 24576 +retinanet.Offline.min_query_count = 24576 +dlrm-v2.Offline.min_query_count = 24576 +bert.Offline.min_query_count = 10833 +gptj.Offline.min_query_count = 13368 +rnnt.Offline.min_query_count = 2513 +3d-unet.Offline.min_query_count = 43 +stable-diffusion-xl.Offline.min_query_count = 5000 +llama2-70b.Offline.min_query_count = 1000 +mixtral-8x7b.Offline.min_query_count = 1000 + +# These fields should be defined and overridden by user.conf. +*.SingleStream.target_latency = 10 +*.MultiStream.target_latency = 80 +*.Server.target_qps = 1.0 +*.Offline.target_qps = 4.0 diff --git a/mlperf/start_server.sh b/mlperf/start_server.sh new file mode 100755 index 00000000..74f9d6b3 --- /dev/null +++ b/mlperf/start_server.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +CACHE_LENGTH=3072 +INPUT_SIZE=512 +OUTPUT_SIZE=512 +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ + +pushd .. 
+python run_server.py \ + --model_name=mixtral \ + --batch_size=128 \ + --max_cache_length=$CACHE_LENGTH \ + --max_decode_length=$OUTPUT_SIZE \ + --context_length=$INPUT_SIZE \ + --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ + --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ + --quantize_weights=1 \ + --quantize_type=int8_per_channel \ + --quantize_kv_cache=1 +popd \ No newline at end of file diff --git a/mlperf/user.conf b/mlperf/user.conf new file mode 100644 index 00000000..2b1fa841 --- /dev/null +++ b/mlperf/user.conf @@ -0,0 +1,3 @@ +mixtral-8x7b.Server.target_qps = 1.8 +mixtral-8x7b.Offline.target_qps = 4.0 + diff --git a/mlperf/warmup.py b/mlperf/warmup.py new file mode 100644 index 00000000..4df66551 --- /dev/null +++ b/mlperf/warmup.py @@ -0,0 +1,192 @@ +import argparse +import asyncio +from dataclasses import dataclass, field +from datetime import datetime +import json +import random +import time +from typing import Any, AsyncGenerator, Optional +import os + + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab +from jetstream.third_party.llama3 import llama3_tokenizer +import numpy as np +from tqdm.asyncio import tqdm # pytype: disable=pyi-error +import pandas + + +@dataclass +class InputRequest: + prompt: str = "" + prompt_len: int = 0 + output: str = "" + output_len: int = 0 + sample_idx: int = -1 + + +@dataclass +class RequestFuncOutput: + input_request: Optional[InputRequest] = None + generated_token_list: list[str] = field(default_factory=list) + generated_text: str = "" + success: bool = False + latency: float = 0 + ttft: float = 0 + prompt_len: int = 0 + + # Flatten the structure and return only the necessary results + def to_dict(self): + return { + "prompt": self.input_request.prompt, + "original_output": self.input_request.output, + "generated_text": self.generated_text, + "success": self.success, + "latency": self.latency, + "prompt_len": self.prompt_len, + "sample_idx": self.input_request.sample_idx, + } + + +async def grpc_async_request( + api_url: str, request: Any +) -> tuple[list[str], float, float]: + """Send grpc synchronous request since the current grpc server is sync.""" + options = [("grpc.keepalive_timeout_ms", 10000)] + async with grpc.aio.insecure_channel(api_url, options=options) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) + print("Making request") + ttft = 0 + token_list = [] + request_start_time = time.perf_counter() + response = stub.Decode(request) + async for resp in response: + if ttft == 0: + ttft = time.perf_counter() - request_start_time + token_list.extend(resp.stream_content.samples[0].token_ids) + latency = time.perf_counter() - request_start_time + print("Done request: ", latency) + return token_list, ttft, latency + + +async def send_request( + api_url: str, + tokenizer: Any, + input_request: InputRequest, + pbar: tqdm, + session_cache: str, + priority: int, +) -> RequestFuncOutput: + """Send the request to JetStream server.""" + # Tokenization on client side following MLPerf standard. 
+ token_ids = np.random.randint(0, 1000, input_request.request_len) + request = jetstream_pb2.DecodeRequest( + session_cache=session_cache, + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + priority=priority, + max_tokens=input_request.output_len, + ) + output = RequestFuncOutput() + output.input_request = input_request + output.prompt_len = input_request.prompt_len + generated_token_list, ttft, latency = await grpc_async_request( + api_url, request + ) + output.ttft = ttft + output.latency = latency + output.generated_token_list = generated_token_list + # generated_token_list is a list of token ids, decode it to generated_text. + output.generated_text = "" + output.success = True + if pbar: + pbar.update(1) + return output + + +async def benchmark( + api_url: str, + max_length: int, + tokenizer: Any = None, + request_rate: float = 0, + disable_tqdm: bool = False, + session_cache: str = "", + priority: int = 100, +): + """Benchmark the online serving performance.""" + + print(f"Traffic request rate: {request_rate}") + + benchmark_start_time = time.perf_counter() + tasks = [] + interesting_buckets = [ + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + ] + + for length in interesting_buckets: + if length > max_length: + break + request = InputRequest() + request.request_len = length + print("send request of length", request.request_len) + tasks.append( + asyncio.create_task( + send_request( + api_url=api_url, + tokenizer=None, + input_request=request, + pbar=None, + session_cache=session_cache, + priority=priority, + ) + ) + ) + outputs = await asyncio.gather(*tasks) + + benchmark_duration = time.perf_counter() - benchmark_start_time + return benchmark_duration, outputs + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + api_url = f"{args.server}:{args.port}" + + benchmark_result, request_outputs = asyncio.run( + benchmark(api_url=api_url, max_length=args.max_length) + ) + print("DURATION:", benchmark_result) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Benchmark the online serving throughput." + ) + parser.add_argument( + "--server", + type=str, + default="0.0.0.0", + help="Server address.", + ) + parser.add_argument("--seed", type=int, default=0) + + parser.add_argument("--port", type=str, default=9000) + parser.add_argument("--max-length", type=int, default=512) + + parsed_args = parser.parse_args() + main(parsed_args) diff --git a/run_interactive.py b/run_interactive.py index eef2def8..e5193db3 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -42,11 +42,17 @@ def main(argv): max_output_length = 1024 profiling_output = FLAGS.profiling_output - profiling_prefill = FLAGS.profiling_prefill - if profiling_output and profiling_prefill: - jax.profiler.start_trace(profiling_output) + profiling_prefill = ( + FLAGS.profiling_prefill + and profiling_output is not None + and profiling_output != "" + ) + if profiling_prefill: + jax.profiler.start_trace(profiling_output) decode_state = engine.init_decode_state() + if profiling_prefill: + jax.profiler.stop_trace() prompts: List[str] = [ "I believe the meaning of life is", "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. 
Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", @@ -62,21 +68,27 @@ def main(argv): print(f"---- Encoded tokens are: {tokens}") # pylint: disable-next=all + if profiling_prefill: + jax.profiler.start_trace(profiling_output) prefill_result, _ = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) # pylint: disable-next=all decode_state = engine.insert(prefill_result, decode_state, slot=slot) + if profiling_prefill: + jax.profiler.stop_trace() + sampled_tokens_list = [] print(f"---- Streaming decode started on #slot{slot}.") complete = np.zeros((1,), dtype=np.bool_) while True: - if profiling_output and not profiling_prefill: + if profiling_output: jax.profiler.start_trace(profiling_output) decode_state, result_tokens = engine.generate(params, decode_state) - if profiling_output and not profiling_prefill: - jax.profiler.stop_trace() result_tokens = result_tokens.convert_to_numpy() + + if profiling_output: + jax.profiler.stop_trace() output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, slot=slot, @@ -94,9 +106,6 @@ def main(argv): print("---- All output text.") print(tokenizer.decode(sampled_tokens_list)) - if profiling_output and profiling_prefill: - jax.profiler.stop_trace() - if __name__ == "__main__": os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" diff --git a/tests/helpers.py b/tests/helpers.py index 00442517..62c0789b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,7 @@ from jetstream_pt import environment -def make_env_tiny(bf16_enable=True): +def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) @@ -26,6 +26,8 @@ def make_env_tiny(bf16_enable=True): environment_data.cache_sequence_length, config.dim // config.n_heads, ) + environment_data.testing = True + env_data_update_fn(environment_data) env = environment.JetEngineEnvironment(environment_data) env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu return env, config diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index dcbcf5f2..73d0ce6c 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -23,15 +23,16 @@ import torch_xla2 from torch.utils import _pytree as pytree - from jetstream_pt.engine import PyTorchEngine from jetstream_pt.third_party.llama import model_exportable, model_args from jetstream_pt.third_party.llama.generation_original import LlamaOriginal from jetstream_pt import environment from tests import helpers +from jetstream_pt import torchjax +from absl.testing import parameterized -class LlamaE2ETest(unittest.TestCase): +class LlamaE2ETest(parameterized.TestCase): """This test class includes all E2E test for llama2""" def _from_torch(self, tree): @@ -187,6 +188,9 @@ def _llama_e2e(self, env, model_arg): model_ours = 
model_exportable.Transformer(model_arg, env) + for k, v in model_ours.state_dict().items(): + if "scale" in k: + state_dict[k] = helpers.to_xla_tensor(v) engine = PyTorchEngine(pt_model=model_ours, env=env) params = self._from_torch(state_dict) @@ -233,6 +237,58 @@ def test_llama_e2e_bfloat16(self): out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertNotEqual(out_tokens, expected_output_tokens) + @parameterized.named_parameters( + ("ring_buffer_f32", True, False, False), + ("left_aligned_f32", False, False, False), + ) + def test_llama_e2e_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + @parameterized.named_parameters( + ("ring_buffer_int8", True, True, True), + ("ring_buffer_bf16", True, False, True), + ("left_aligned_int8", False, True, True), + ("left_aligned_bf16", False, False, True), + ) + def test_llama_e2e_no_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertNotEqual(out_tokens, expected_output_tokens) + # pylint: disable-next=all def test_llama_e2e_two_addtional_tokens(self): """end to end jetstream llama with addtional tokens""" diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 4d4ddfd6..703ce444 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -65,9 +65,13 @@ def _make_freqs_cis(self, model_arg, seqlen, start_pos): freqs_cis = freqs_cis[start_pos : start_pos + seqlen] return freqs_cis - def _generate_mask(self, cache_length, pos, seqlen): + def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True): x = jnp.arange(0, cache_length) - cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + if ring_buffer: + cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + else: + # Left aligned buffer we postpone the cache update + cond = jnp.logical_and(x < pos, x >= pos - seqlen) res = jnp.where(cond, 0, float("-inf")) return torchjax.to_torch(res) @@ -91,6 +95,7 @@ def _make_one_cache_for_generate(self, env, pos): # pylint: disable-next=all def test_attention(self): + torch.manual_seed(0) env, model_arg = helpers.make_env_tiny(False) attention_orig 
= model_original.Attention(model_arg) @@ -101,6 +106,7 @@ def test_attention(self): hidden_size=model_arg.dim, device="cpu", env=env, + layer_id=0, ) seqlen = 32 @@ -136,11 +142,11 @@ def test_attention(self): # insert prefilled cache entry cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_k._elem) cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_v._elem) # self._compare_cache(attention_orig.cache_k, cache_decode.cache_k) @@ -154,7 +160,7 @@ def test_attention(self): None, # mask is none for decode ) expected_out = attention_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) @@ -203,6 +209,7 @@ def init_weights(model): head_dim=head_dim, device="meta", env=env, + layer_id=0, ) def load_hook(state_dict, prefix, *args): @@ -228,8 +235,8 @@ def load_hook(state_dict, prefix, *args): freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) mask = self._prefill_mask(seqlen, start_pos) kv_write_indexes = torch.arange(0, seqlen) - cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) - cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) + cache_k = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) + cache_v = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) expected_out = attention_orig(*inputs_orig) @@ -300,10 +307,10 @@ def test_transformer_block(self): # insert prefilled cache entry cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_k._elem) cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_v._elem) # Now do one with decode @@ -316,7 +323,7 @@ def test_transformer_block(self): None, # mask is none for decode ) expected_out = block_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 190e6fda..f48809ea 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -22,7 +22,7 @@ import torch import torch_xla2 from jax.experimental import mesh_utils -from jetstream_pt import cache_manager, layers, quantize, torchjax +from jetstream_pt import cache_manager, layers, quantize, torchjax, environment from jetstream_pt.environment import QuantizationConfig from jetstream_pt.layers import ( WeightOnlyBlockwiseQuantizedLinear, @@ -33,11 +33,13 @@ from tests import helpers from torch.utils import _pytree as pytree from torch_xla2 import tensor +import copy +from absl.testing import parameterized torch.manual_seed(12345) -class QuantizationTest(unittest.TestCase): +class QuantizationTest(parameterized.TestCase): """test kv cache quantization""" def _xla_tensor(self, shape): @@ -70,72 +72,216 @@ def _print_diff(self, w, w_dq): print(" norm: ", (w - w_dq).norm()) print(" cosine dist: ", self._calc_cosine_dist(w, w_dq)) - def test_kv_cache(self): + @parameterized.named_parameters( + ("ring_buffer", 
True), + ("left_aligned", False), + ) + def test_kv_cache(self, ring_buffer): """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + + env, _ = helpers.make_env_tiny(True, update_env_data) + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # layer, bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny() - cache = cache_manager.Int8KVCacheGenerate.empty( - cache_shape, None, False, env - ) - # seqlen is 1 - k = self._xla_tensor((3, 2, 1, 2)) - v = self._xla_tensor((3, 2, 1, 2)) - cache.input_pos = [57] - new_k, new_v, scaler_k, scaler_v = cache.update(k, v) - new_k = new_k * scaler_k - new_v = new_v * scaler_v + cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, env) + # seqlen is 1 + k = self._xla_tensor((batch, 2, 1, 2)) + v = self._xla_tensor((batch, 2, 1, 2)) + + def update_finalize_compare(in_k, in_v, in_layer, in_pos): + cache.input_pos = ( + [in_pos] if env.ring_buffer else jnp.array([in_pos] * batch) + ) + + # layer id may or may not take effect, depends on the env config. + cache.update(in_k, in_v, layer_id=in_layer) + cache.finalize() + if env.quant_config.enable_kv_quantization: + new_k = cache.cache_k * cache.k_scaler + new_v = cache.cache_v * cache.v_scaler + else: + new_k = cache.cache_k + new_v = cache.cache_v + + if env.generate_cache_stacked: + self.assertTrue( + jnp.allclose( + k._elem, + new_k._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) + ) + self.assertTrue( + jnp.allclose( + v._elem, + new_v._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) + ) + else: + self.assertTrue( + jnp.allclose( + k._elem, new_k._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) + ) + self.assertTrue( + jnp.allclose( + v._elem, new_v._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) + ) + + update_finalize_compare(k, v, in_layer=1, in_pos=57) + update_finalize_compare(k, v, in_layer=1, in_pos=58) + update_finalize_compare(k, v, in_layer=2, in_pos=3) + + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + def test_kv_kernel(self, ring_buffer): + """test kv cache quantization""" - self.assertTrue( - jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1) - ) - self.assertTrue( - jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1) - ) + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + + env, _ = helpers.make_env_tiny(False, update_env_data) + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # layers, bs, num 
+
+  @parameterized.named_parameters(
+      ("ring_buffer", True),
+      ("left_aligned", False),
+  )
+  def test_kv_kernel(self, ring_buffer):
+    """test kv cache quantization"""

-      self.assertTrue(
-          jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1)
-      )
-      self.assertTrue(
-          jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1)
-      )
+    def update_env_data(env_data):
+      env_data.ring_buffer = ring_buffer
+      env_data.ragged_mha = not ring_buffer
+      env_data.flash_attention = not ring_buffer
+      env_data.generate_cache_stacked = not ring_buffer
+      env_data.new_cache_stacked = not ring_buffer
+      env_data.lazy_cache_update = not ring_buffer
+      env_data.quant_config.enable_kv_quantization = True
+      env_data.batch_size = 4
+
+    env, _ = helpers.make_env_tiny(False, update_env_data)
+
+    batch = env.batch_size
+    if env.generate_cache_stacked:
+      cache_shape = (
+          env.num_layers,
+          batch,
+          2,
+          100,
+          2,
+      )  # layer, bs, num heads, seqlen, dim
+    else:
+      cache_shape = (batch, 2, 100, 2)  # bs, num heads, seqlen, dim

-  def test_kv_kernel(self):
-    """test kv cache quantization"""
-    cache_shape = (3, 2, 100, 2)  # bs, num heads, seqlen, dim
     with jax.default_device(jax.devices("cpu")[0]):
-      env, _ = helpers.make_env_tiny(False)
+
       key = jax.random.PRNGKey(123)
       key2 = jax.random.PRNGKey(456)
-      cache_k_jax = jax.random.normal(key, cache_shape)
-      cache_v_jax = jax.random.normal(key2, cache_shape)
+      cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_type)
+      cache_v_jax = jax.random.normal(key2, cache_shape, dtype=env.default_type)

-      cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))
+      start = jnp.zeros((batch,), dtype=jnp.int32)

-      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env)
+      cache_k, cache_v, start = torchjax.to_torch(
+          (cache_k_jax, cache_v_jax, start)
+      )
+
+      # Prepare the quantized cache before anything is written to it.
+      cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (-3, -1))
+      cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1))

       # 1 is seqlen
-      xq = jax.random.normal(key, (3, 2, 1, 2))
-      xk = jax.random.normal(key, (3, 2, 1, 2))
-      xv = jax.random.normal(key, (3, 2, 1, 2))
+      xq = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type)
+      xk = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type)
+      xv = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type)

       xq, xk, xv = torchjax.to_torch((xq, xk, xv))

-      attention_float = layers.AttentionKernel(env)
-      float_res = attention_float(xq, xk, xv, None, cache)
+      def get_var(position: int):
+        pos = (
+            [position]
+            if env.ring_buffer
+            else jnp.array([position] * batch, dtype=jnp.int64)
+        )
+        mask = jax.lax.broadcast_in_dim(
+            jnp.array([0] * position + [float("-inf")] * (100 - position)),
+            (env.batch_size, 1, 1, 100),
+            (3,),
+        )
+        mask = torchjax.to_torch((mask))
+        return pos, mask
+
+      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, None, None, env)
+      # layer_id doesn't matter here; it is assigned later.
+      attention_float = layers.AttentionKernel(env, layer_id=0)
+
+      float_res = []
+
+      def update_finalize_record(
+          in_attention, in_cache, in_q, in_k, in_v, in_layer, in_pos
+      ):
+        pos, mask = get_var(in_pos)
+        in_attention.layer_id = in_layer
+        in_cache.input_pos = pos
+        ret = in_attention(
+            in_q, in_k, in_v, mask, in_cache, start=start, end=pos
+        )
+        in_cache.finalize()
+        return ret
+
+      float_res.append(
+          update_finalize_record(attention_float, cache, xq, xk, xv, 1, 57)
+      )
+      float_res.append(
+          update_finalize_record(attention_float, cache, xq, xk, xv, 1, 58)
+      )
+      float_res.append(
+          update_finalize_record(attention_float, cache, xq, xk, xv, 2, 3)
+      )

-      # ==
+      # Multiple env objects always share the same quant_config, so record
+      # the results and compare them afterwards as a workaround.
+      env._data.quant_config.enable_kv_quantization = True
+      env = environment.JetEngineEnvironment(env._data)

-      cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))
-      cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3))
-      cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3))
       cache_int = cache_manager.Int8KVCacheGenerate(
           cache_k_int,
           cache_v_int,
           cache_k_scaler,
           cache_v_scaler,
-          [0],
+          None,
           None,
           env,
       )
-      attention_quant = layers.Int8KVAttentionKernel(env)
-      int_res = attention_quant(xq, xk, xv, None, cache_int)
+      # layer_id doesn't matter here; it is assigned later.
+      attention_quant = layers.Int8KVAttentionKernel(env, layer_id=0)
+
+      int_res = []
+      int_res.append(
+          update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 57)
+      )
+      int_res.append(
+          update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 58)
+      )
+      int_res.append(
+          update_finalize_record(attention_quant, cache_int, xq, xk, xv, 2, 3)
+      )

-      self.assertTrue(jnp.allclose(float_res.jax(), int_res.jax(), atol=0.01))
+      for f, i in zip(float_res, int_res):
+        self.assertTrue(jnp.allclose(f.jax(), i.jax(), atol=0.01))
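`test_kv_kernel` above pre-quantizes the float cache with `quantize_tensor(cache_k, (-3, -1))` and then checks that the int8 attention path matches the float path within `atol=0.01`. For intuition only, a rough sketch of symmetric per-channel int8 quantization over a set of reduction axes; `int8_quantize` and `int8_dequantize` are invented helpers, not the repo's `quantize_tensor`:

```python
import jax
import jax.numpy as jnp


def int8_quantize(x, reduce_axis):
  # One scale per slice that remains after reducing over `reduce_axis`
  # (symmetric quantization, so no zero point).
  scale = jnp.amax(jnp.abs(x), axis=reduce_axis, keepdims=True) / 127.0
  q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
  return q, scale


def int8_dequantize(q, scale):
  return q.astype(jnp.float32) * scale


key = jax.random.PRNGKey(0)
cache_k = jax.random.normal(key, (4, 2, 100, 2))  # bs, num heads, seqlen, dim
q_k, k_scale = int8_quantize(cache_k, reduce_axis=(-3, -1))
assert q_k.dtype == jnp.int8 and k_scale.shape == (4, 1, 100, 1)
assert jnp.allclose(int8_dequantize(q_k, k_scale), cache_k, atol=0.1)
```

In this sketch, reducing over the head and feature axes leaves one scale per batch entry and cache position, so dequantization is a cheap elementwise rescale at decode time.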

   def test_quantize_dequantize_tensor(self):