diff --git a/README.md b/README.md index 5292e0c7..ca6ec4ba 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,7 @@ Note: Get address ip and port information from ray head. Here is an example to run the server with ray for llama2 7B model: ```bash +export DISABLE_XLA2_PJRT_TEST="true" python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 2abda049..1fdc0cb7 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -32,7 +32,7 @@ flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") -def run_prefill_time(engine, params, decode_state, seqlen): +def run_prefill_time(engine, params, decode_state, seqlen, profiler_started): """Run prefill and measure time.""" metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) @@ -53,6 +53,10 @@ def run_prefill_time(engine, params, decode_state, seqlen): nums = 5 start = time.perf_counter() for i in range(nums): + if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started: + jax.profiler.start_trace(FLAGS.profiling_output) + profiler_started = True + prefill_result, _ = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) @@ -60,8 +64,9 @@ def run_prefill_time(engine, params, decode_state, seqlen): prefill_result, decode_state, slot=jnp.int32(i) ) jax.block_until_ready(decode_state) + end = time.perf_counter() - return (end - start) / nums, decode_state + return (end - start) / nums, decode_state, profiler_started MAXTEXT_PREFILL = { @@ -86,9 +91,10 @@ def main(argv): prefill_times = {} decode_state = engine.init_decode_state() + profiler_started = False for batch, _ in MAXTEXT_PREFILL.items(): - runtime, decode_state = run_prefill_time( - engine, params, decode_state, batch + runtime, decode_state, profiler_started = run_prefill_time( + engine, params, decode_state, batch, profiler_started ) prefill_times[batch] = runtime @@ -103,10 +109,12 @@ def main(argv): profiling_output = FLAGS.profiling_output print("======= decode starting ===") + dec_times = [] for i in range(10): - if profiling_output and i == 7: + if profiling_output and i == 7 and not profiler_started: jax.profiler.start_trace(profiling_output) + profiler_started = True start = time.perf_counter() # pylint: disable-next=all decode_state, sampled_tokens = engine.generate(params, decode_state) @@ -116,7 +124,7 @@ def main(argv): dec_times.append(end - start) print(i, "decode time", (end - start)) - if profiling_output: + if profiler_started: jax.profiler.stop_trace() print("prefill ", prefill_times) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 96bb4233..6d571d2c 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -16,6 +16,7 @@ def ragged_flash_attention_kernel( + layer_ref, start_ref, end_ref, line_end_ref, @@ -105,30 +106,45 @@ def run(): @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], ) def ragged_mqa( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, - k_scaler: jax.Array | 
None = None, - v_scaler: jax.Array | None = None, ragged_batch_index=None, ragged_block_index=None, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] + batch_size, time, head_dim = q.shape + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 5: + stacked = True def kv_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -136,11 +152,20 @@ def kv_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + + if stacked: + return ( + layer_ref[0], + ragged_batch_index_ref[index], + ragged_block_index_ref[index], + 0, + ) return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 def q_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -148,17 +173,32 @@ def q_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + if stacked: + return layer_ref[0], ragged_batch_index_ref[index], 0, 0 return ragged_batch_index_ref[index], 0, 0 - def scaler_index_map(b, i, *_): + def scaler_index_map(b, i, layer_ref, *_): + if stacked: + return layer_ref[0], b, 0, i return b, 0, i line_end = jnp.where(start < end, end, seq_len - 1) + if stacked: + q_bp = (None, None, time, head_dim) + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + else: + q_bp = (None, time, head_dim) + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(q_index_map, q_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(scaler_index_map, ks_bp), + pl.BlockSpec(scaler_index_map, ks_bp), ] inputs = ( start, @@ -169,15 +209,9 @@ def scaler_index_map(b, i, *_): q, k, v, + k_scaler, + v_scaler, ) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True out, m, l = pl.pallas_call( functools.partial( @@ -191,33 +225,241 @@ def scaler_index_map(b, i, *_): num_scalar_prefetch=5, in_specs=in_specs, out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), ], grid=(batch_size, seq_len // bk), ), compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + interpret=testing, out_shape=[ q, - jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 - ), - jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 - ), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), ], )(*inputs) return out, (m[..., 0], l[..., 0]) +def ragged_mqa_kernel_reference( + layer_ref, + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + bk: 
int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for ragged attention.""" + b, i = pl.program_id(0), pl.program_id(1) + del layer_ref + + @pl.when(i == 0) + def init(): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + # length = lengths_ref[b] + # Always start from 0, left aligned + length = end_ref[b] + + @pl.when(i * bk < length) + def run(): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + + if normalize_var: + qk = qk / math.sqrt(k.shape[-1]) # Align with meta llama + # Quantized + if quantized: + qk = qk * k_scaler_ref[...] + + mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length + qk = qk + jnp.where(mask, 0.0, mask_value) + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + # Quantized + if quantized: + s_curr = s_curr * v_scaler_ref[...] + + o_curr_times_l_curr = jnp.dot(s_curr, v) + + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + + @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], +) +def ragged_mqa_reference( + q: jax.Array, + k: jax.Array, + v: jax.Array, + layer, + start: jax.Array, + end: jax.Array, + ragged_batch_index=None, + ragged_block_index=None, + k_scaler: jax.Array = None, + v_scaler: jax.Array = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + batch_size, time, head_dim = q.shape + # assert end.shape == (batch_size,) + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 4: + stacked = True + + def _compute_ragged_block_indices(b, i, lengths_ref): + length = lengths_ref[b] + not_done = i * bk < length + am_last_batch = b == batch_size - 1 + # if length < bk, then it's -1, should be 0? 
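+    # `last_good_block` is the last fully filled bk-sized block for this batch
+    # entry (it is -1 when length < bk, which is the open question noted above).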
+ last_good_block = jax.lax.div(length, bk) - 1 + + # if not done, then still work on b, otherwise next batch + b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1)) + # if not done, i next = i + # if done + # if last batch, previous good block + # if not last batch, i next = 0 + i_next = jnp.where( + not_done, i, jnp.where(am_last_batch, last_good_block, 0) + ) + return b_next, i_next + + def kv_index_map(b, i, layer_ref, start_ref, end_ref, *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, i_next, 0 + return b_next, i_next, 0 + + def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, 0, i_next + return b_next, 0, i_next + + if stacked: + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + else: + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + + in_specs = [ + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # q + pl.BlockSpec(kv_index_map, kv_bp), # k + pl.BlockSpec(kv_index_map, kv_bp), # v + pl.BlockSpec(kv_scale_index_map, ks_bp), # k_scaler + pl.BlockSpec(kv_scale_index_map, ks_bp), # v_scaler + ] + + inputs = ( + jnp.array([layer]), + start, + end, + end, # line_end, not actually used + ragged_batch_index, + ragged_block_index, + q, + k, + v, + k_scaler, + v_scaler, + ) + + out, m, l = pl.pallas_call( + functools.partial( + ragged_mqa_kernel_reference, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + interpret=testing, + # debug=True, + compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial( + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "q_shard_axis", + "kv_shard_axis", + "testing", + ], ) def ragged_mha( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, ragged_batch_index: jax.Array, @@ -227,7 +469,9 @@ def ragged_mha( bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, - shard_axis: int = 1, + q_shard_axis: int = 0, + kv_shard_axis: int = 0, + testing: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi head attention. Args: @@ -251,35 +495,66 @@ def ragged_mha( softmax denominator ([batch_size, num_heads, compute_dim, 1]). 
""" mask_value = DEFAULT_MASK_VALUE + bk = min(bk, k.shape[-2]) + bq, hq, tq, dq = q.shape + hkv = k.shape[-3] + tk = k.shape[-2] + + assert k.shape[-1] == q.shape[-1] + assert k.shape[-4] == q.shape[-4] + + rep = hq // hkv + if rep > 1: + q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) + stacked = k.ndim == 5 + + replicated_in_axes = 7 if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (ragged_batch_index, ragged_block_index) + quantized = False + if k.ndim == 5: + kv_scale_shape = (k.shape[0], bq, 1, tk) + else: + kv_scale_shape = (bq, 1, tk) + k_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) + v_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) else: - replicated_in_axes = 6 - replicated_inputs = ( - jnp.squeeze(k_scaler, -1), - jnp.squeeze(v_scaler, -1), - ragged_batch_index, - ragged_block_index, - ) + quantized = True + k_scale = jnp.squeeze(k_scaler, -1) + v_scale = jnp.squeeze(v_scaler, -1) + + if stacked: + assert k_scale.shape == (k.shape[0], bq, 1, tk) + else: + assert k_scale.shape == (bq, 1, tk) + + replicated_inputs = ( + ragged_batch_index, + ragged_block_index, + k_scale, + v_scale, + ) + # New cache has t=1 with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - ragged_mqa, + # ragged_mqa, + ragged_mqa_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, + testing=testing, + quantized=quantized, # out_dtype=out_dtype, ), in_axes=( - shard_axis, - shard_axis, - shard_axis, + q_shard_axis, + kv_shard_axis, + kv_shard_axis, *([None] * replicated_in_axes), ), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) + out_axes=q_shard_axis, + )(q, k, v, layer, start, end, *replicated_inputs) return out, (m, l) @@ -310,15 +585,77 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): return output +def flash_attention( + xq, + keys, + values, + layer, + k_scaler=None, + v_scaler=None, + mask=None, + normalize_var=True, +): + """Flash attention kernel.""" + if keys.ndim == 5: + keys = keys[layer] + values = values[layer] + k_scaler = k_scaler[layer] if k_scaler is not None else None + v_scaler = v_scaler[layer] if v_scaler is not None else None + + logits = torch.einsum( + "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) + ) + + if normalize_var: + logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama + # Quantized + if k_scaler is not None: + logits = logits * k_scaler.reshape( + k_scaler.shape[-4], 1, 1, k_scaler.shape[-2] + ) + + # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] + if mask is not None: + # logits = logits + jnp.where(mask, 0.0, DEFAULT_MASK_VALUE)[:, None] + logits = logits + mask + + logits_max, _ = torch.max(logits, axis=-1, keepdim=True) + unnormalized = torch.exp(logits - logits_max) + # Quantized, should not put here, otherwise sum will have this too, which cancels with denominator + # unnormalized = unnormalized * v_scaler + + denominator = unnormalized.sum(axis=-1, keepdim=True) + if v_scaler is not None: + unnormalized = unnormalized * v_scaler.reshape( + v_scaler.shape[-4], 1, 1, v_scaler.shape[-2] + ) + o = ( + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) + / denominator + ) + + return o, (logits_max, denominator) + + class RaggedAttentionKernel: """Ragged attention kernel.""" - def __init__(self, env, input_specs, output_specs, sharding_axis): + def __init__( + self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis + ): self.binded_ragged_mha = 
functools.partial( - ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ragged_mha, + bk=env.block_size, + q_shard_axis=q_shard_axis, + kv_shard_axis=kv_shard_axis, + testing=env.testing, ) self.binded_ragged_mha = shard_map( - ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + self.binded_ragged_mha, + env.mesh, + input_specs, + output_specs, + check_rep=False, ) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) @@ -327,6 +664,7 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, @@ -338,6 +676,7 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 13789f91..76f44120 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -14,7 +14,10 @@ import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map import torch +import torch_xla2 + from jetstream_pt import torchjax @@ -38,19 +41,20 @@ def update(self, key, value): class KVCachePrefill: """Prefill kv cache""" - def __init__(self, kv_quantize=False): + def __init__(self, kv_quantize=False, stacked=False): self.kv_quantize = kv_quantize self.cache_k = None self.cache_v = None + self.stacked = stacked - def update(self, key, value): + def update(self, key, value, layer_id): """This cache just remembers the stuff.""" self.cache_k = key self.cache_v = value if self.kv_quantize: # pretend to be quantized bsz, _, seq, _ = key.shape ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)) - return key, value, ones, ones + return key, value, None, None, ones, ones, None, None return key, value @@ -58,6 +62,11 @@ def state(self): """Get prefill cache state""" return self.cache_k, self.cache_v + # Placeholder, to match with GenerateCache + def finalize(self): + """Finalize the cache operation and updates the cache.""" + return + # pylint: disable-next=all def KVCachePrefill_flatten(cache): @@ -80,57 +89,225 @@ def KVCachePrefill_unflatten(auxdata, data): ) -# Refactor out cache management -# Easier to test for quantized kv cache class KVCacheGenerate: """Kvache generator without quantization""" + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. def __init__( self, cache_k: torch.Tensor, # previous cache cache_v: torch.Tensor, # previous cache - position: int, # position to store the cache + position: int | torch.Tensor, # position to store the cache sharding, env=None, ): super().__init__() self.cache_k = cache_k self.cache_v = cache_v - self.pos = position + self.input_pos = position self.sharding = sharding self.env = env - def update(self, key, value): + self.new_ks = None + self.new_vs = None + self.env = env + # Keep this one it's used in the specific model code. 
+ self.stacked = env.generate_cache_stacked + self.batch = jnp.arange(self.env.batch_size) + # The other way is to store the list and loop over to insert in finalize() + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + layer, batch, heads, _, dim = self.cache_k.shape + new_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_dim, dtype=self.env.default_type), + jnp.zeros(new_dim, dtype=self.env.default_type), + ) + ) + else: + self.new_ks, self.new_vs = [], [] + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + none_pspec = self.env.partition_by_axis() + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) + out_specs = (cache_pspec, cache_pspec) + self.update_single_cache_line = jax.jit( + shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + ) + + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): + """The shard map version of single cache line update.""" + b = cache_k.shape[-4] + for bb, pp in enumerate(pos.reshape(b)): + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + # We are not handling generate_cache_stacked=True new_cache_stacked=False here + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + return cache_k, cache_v + + def finalize(self): + """Finalize the cache operation and updates the cache.""" + if not self.env.lazy_cache_update: + return + + if self.env.ring_buffer: + # Assume no cache stack for ring buffer + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) + ) + else: + if self.env.generate_cache_stacked: + _, b, head, _, dim = self.cache_k.shape + if self.env.new_cache_stacked: + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.input_pos, + ) + else: + for i in range(self.env.num_layers): + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_ks[i].jax().reshape(b, head, dim)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_vs[i].jax().reshape(b, head, dim)) + ) + else: + # Try to use shard_map to get rid of the data copy + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.input_pos, + ) + + def update(self, key, value, layer_id: int): """Update kv cache""" keyj, valuej = torchjax.to_torch((key, value)) + if 
self.env.lazy_cache_update: + if self.env.new_cache_stacked: + assert ( + self.env.generate_cache_stacked + ), "When new cache stacked, must have generate_cache_stacked!" + self.new_ks[layer_id, ...] = keyj + self.new_vs[layer_id, ...] = valuej + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Generate cache stacked, but new cache unstacked + if self.env.generate_cache_stacked: + self.new_ks.append(keyj) + self.new_vs.append(valuej) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # all cache unstacked + self.new_ks = keyj + self.new_vs = valuej + return self.cache_k, self.cache_v + if self.env.ring_buffer: + assert ( + not self.env.new_cache_stacked and not self.env.generate_cache_stacked + ), "Ring buffer doesn't support stacked cache." # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(keyj) + ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej) - else: - batch = jnp.arange(self.env.batch_size) + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(valuej) + ) + return self.cache_k, self.cache_v + + # Non lazy cache update, non ring buffer, generate cache stacked + if self.env.generate_cache_stacked: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set( - keyj.squeeze(2) + self.cache_k._elem = ( + self.cache_k.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set( - valuej.squeeze(2) + self.cache_v._elem = ( + self.cache_v.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) ) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Non lazy cache update, non ring buffer, generate cache non stacked + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) return self.cache_k, self.cache_v def state(self): """Get kv cache state""" - # pylint: disable-next=all return self.cache_k.jax(), self.cache_v.jax() @classmethod - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" - default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32 - k = jnp.zeros(shape, device=device, dtype=default_dtype) - v = jnp.zeros(shape, device=device, dtype=default_dtype) + default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 + in_shape = shape + if env.testing: + key = jax.random.key(env.testing_seed) + k_key, v_key = jax.random.split(key) + k = jax.random.uniform(k_key, shape=in_shape, dtype=default_dtype) + v = jax.random.uniform(v_key, shape=in_shape, dtype=default_dtype) + else: + k = jnp.zeros(in_shape, device=device, dtype=default_dtype) + v = jnp.zeros(in_shape, device=device, dtype=default_dtype) k, v = torchjax.to_torch((k, v)) return cls(k, v, 0, device, env=env) @@ -159,7 +336,8 @@ def KVCacheGenerate_unflatten(auxdata, data): class Int8KVCacheGenerate: """Int8 quantized kvache with scalers""" - # pylint: disable-next=all + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. 
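+  # Holds the int8 K/V cache plus per-token quantization scalers. With
+  # lazy_cache_update, the current step's values are staged in the new_*
+  # buffers and written back to the main cache in finalize().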
def __init__( self, cache_k, @@ -175,9 +353,153 @@ def __init__( self.cache_v = cache_v self.k_scaler = cache_k_scaler self.v_scaler = cache_v_scaler + self.new_ks = None + self.new_vs = None + self.new_k_scaler = None + self.new_v_scaler = None + + self.batch = jnp.arange(env.batch_size) self.input_pos = input_pos self.sharding = sharding self.env = env + self.stacked = env.generate_cache_stacked + + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + layer, batch, heads, _, dim = self.cache_k.shape + new_kv_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_kv_dim, dtype=jnp.int8), + jnp.zeros(new_kv_dim, dtype=jnp.int8), + ) + ) + if self.env.new_cache_stacked: + new_scale_dim = (layer, batch, 1, 1, 1) + self.new_k_scaler, self.new_v_scaler = torchjax.to_torch( + ( + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + ) + ) + else: + self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = ( + [], + [], + [], + [], + ) + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + new_cache_pspec = ( + self.env.partition_by_axis(2) + if self.env.new_cache_stacked + else self.env.partition_by_axis(1) + ) + none_pspec = self.env.partition_by_axis() + in_specs = ( + *([cache_pspec] * 2), + *([new_cache_pspec] * 2), + *([none_pspec] * 5), + ) + out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) + self.update_single_cache_line = shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + self.update_single_cache_line = jax.jit(self.update_single_cache_line) + + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. 
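+  # Scatters each batch entry's staged K/V slice (and its scaler) into the big
+  # cache at that entry's insert position; the shard_map wrapper above keeps
+  # every device writing only to its own shard.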
+ def update_single_cache_line( + self, + cache_k, + cache_v, + new_ks, + new_vs, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + pos, + ): + """The shard map version of single cache line update.""" + b = cache_k.shape[-4] + + for bb, pp in enumerate(pos.reshape(b)): + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + if self.env.generate_cache_stacked and not self.env.new_cache_stacked: + for layer in range(self.env.num_layers): + update_start_indices = (layer, bb, 0, pp, 0) + new_ks_slice = jax.lax.dynamic_slice_in_dim( + new_ks[layer], bb, 1, slice_dim + ) + new_ks_slice = jnp.expand_dims(new_ks_slice, 0) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim( + new_vs[layer], bb, 1, slice_dim + ) + new_vs_slice = jnp.expand_dims(new_vs_slice, 0) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler[layer], bb, 1, slice_dim + ) + new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler[layer], bb, 1, slice_dim + ) + new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + else: + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler, bb, 1, slice_dim + ) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler, bb, 1, slice_dim + ) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + + return cache_k, cache_v, k_scaler, v_scaler def state(self): """Get kv cache state""" @@ -189,13 +511,17 @@ def scalers(self): @classmethod # pylint: disable-next=all - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - # bf16_enable is a placeholder parameter, it's not used in Int8KVCache - kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + + if env.generate_cache_stacked: + s_shape = (shape[0], shape[1], 1, shape[3], 1) + else: + s_shape = (shape[0], 1, shape[2], 1) + kscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) + vscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) @@ -205,23 +531,126 @@ def empty(cls, shape, device, bf16_enable, env): def quantize(self, val): """Quantize value""" # val is (batch, heads, seqlen, dim) - scale = torch.amax(val.abs(), axis=(1, 3), keepdim=True) + scale = torch.amax(val.abs(), axis=(-3, -1), keepdim=True) scale = 
scale / 127 return (val / scale).to(torch.int8), scale - def update(self, xk, xv): + def update(self, xk, xv, layer_id: int): """Update kv cache""" k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) - if self.env.ring_buffer: + + if self.env.lazy_cache_update: + if self.env.new_cache_stacked: + self.new_ks[layer_id, ...] = k_quant + self.new_vs[layer_id, ...] = v_quant + self.new_k_scaler[layer_id, ...] = kscale + self.new_v_scaler[layer_id, ...] = vscale + else: + if self.env.generate_cache_stacked: + self.new_ks.append(k_quant) + self.new_vs.append(v_quant) + self.new_k_scaler.append(kscale) + self.new_v_scaler.append(vscale) + else: + self.new_ks = k_quant + self.new_vs = v_quant + self.new_k_scaler = kscale + self.new_v_scaler = vscale + elif self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant self.k_scaler[:, :, self.input_pos, :] = kscale self.v_scaler[:, :, self.input_pos, :] = vscale else: - batch = jnp.arange(self.env.batch_size) - self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2) - self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2) - self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2) - self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2) - return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler + # We don't handle left aligned but lazy_cache_update=False + self.cache_k[self.batch, :, self.input_pos, :] = k_quant.squeeze(2) + self.cache_v[self.batch, :, self.input_pos, :] = v_quant.squeeze(2) + self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) + self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) + + return ( + self.cache_k, + self.cache_v, + k_quant, + v_quant, + self.k_scaler, + self.v_scaler, + kscale, + vscale, + ) + + def finalize(self): + """Finalize the cache operation and updates the cache.""" + if not self.env.lazy_cache_update: + return + if self.env.ring_buffer: + # Assume no cache stack for ring buffer + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) + ) + else: + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + # new kv scaler also has to go through shard_map instead of indexing + # because it needs to reshape to (batch, layer) which mess up with the data + caches = [ + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + else: + caches = [ + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + else: + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + self.input_pos, + ) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 
78f8da9f..70b530fc 100644
--- a/jetstream_pt/config.py
+++ b/jetstream_pt/config.py
@@ -90,6 +90,31 @@
     "Whether to enable ring buffer",
     required=False,
 )
+flags.DEFINE_bool(
+    "flash_attention",
+    False,
+    "Whether to enable flash attention. Only takes effect in test mode",
+    required=False,
+)
+flags.DEFINE_bool(
+    "generate_cache_stacked",
+    False,
+    "Whether to stack the generate cache to the layer dimension. Only takes effect in test mode",
+    required=False,
+)
+flags.DEFINE_bool(
+    "new_cache_stacked",
+    False,
+    "Whether to stack the new cache to the layer dimension. Only takes effect in test mode",
+    required=False,
+)
+flags.DEFINE_bool(
+    "lazy_cache_update",
+    False,
+    "Whether to update the cache during attention or delay the update until all the layers are done. "
+    "Only takes effect in test mode",
+    required=False,
+)
 flags.DEFINE_float(
     "temperature",
     1.0,
@@ -184,6 +209,10 @@ def create_engine_from_config_flags():
       nucleus_topp=FLAGS.nucleus_topp,
       topk=FLAGS.topk,
       ring_buffer=FLAGS.ring_buffer,
+      flash_attention=FLAGS.flash_attention,
+      generate_cache_stacked=FLAGS.generate_cache_stacked,
+      new_cache_stacked=FLAGS.new_cache_stacked,
+      lazy_cache_update=FLAGS.lazy_cache_update,
   )
 
   print("Initialize engine", time.perf_counter() - start)
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 0b27db3d..79dfb945 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -145,7 +145,7 @@ def init_decode_state(
             (self.env.batch_size, self.env.cache_sequence_length),
             float("-inf"),
             dtype=self.default_dtype,
-        ),
+        ),  # mask
     )
 
   # pylint: disable-next=all
@@ -195,7 +195,10 @@ def _call_model_generate(
       # The mode is needed so that tensors created inside of
       # the model (such as via torch.ones etc) also have the right type
       res = torch.func.functional_call(self.pt_model, paramst, argst)
-      updated_caches = [c.state() for c in caches_obj]
+      updated_caches = []
+      for c in caches_obj:
+        c.finalize()
+        updated_caches.append(c.state())
       scales = []
       if self.env.quant_config.enable_kv_quantization:
         scales = [c.scalers() for c in caches_obj]
@@ -218,7 +221,8 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
         dtype=self.default_dtype,
     )
     mask = jnp.triu(mask, k=1)
-    args = (tokens, input_indexes, caches, mask)
+    start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32)
+    args = (tokens, input_indexes, caches, mask, start)
 
     paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
@@ -328,7 +332,7 @@ def _insert_no_wrap(
     tokens = decode_state.tokens.at[slot].set(prefix.token)
 
     x = jnp.arange(0, self.env.cache_sequence_length)
-    cond = jnp.logical_and(x <= current_pos, x >= pos)
+    cond = jnp.logical_and(x < current_pos, x >= pos)
     mask_insert = jnp.where(cond, 0, float("-inf"))
     mask = decode_state.mask.at[slot].set(mask_insert)
     start = decode_state.start.at[slot].set(
@@ -338,31 +342,48 @@ def _insert_no_wrap(
     if not self.env.quant_config.enable_kv_quantization:
 
       @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
-      def insert(cache, new_entry):
+      def insert(cache, new_entry, update_index):
         res = jax.lax.dynamic_update_slice(
             cache,
             new_entry,
-            [slot, 0, pos, 0],
+            update_index,
         )
         res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
         return res
 
-      caches = [
-          (insert(k, newk), insert(v, newv))
-          for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches)
-      ]
+      if self.env.generate_cache_stacked:
+        caches = decode_state.caches
+        for idx, (newk, newv) in enumerate(prefix.caches):
+          update_index = [idx, slot, 0, pos, 0]
+          newk = 
jnp.expand_dims(newk, 0) + newv = jnp.expand_dims(newv, 0) + caches = [ + ( + insert(caches[0][0], newk, update_index), + insert(caches[0][1], newv, update_index), + ) + ] + else: + update_index = [slot, 0, pos, 0] + caches = [ + (insert(k, newk, update_index), insert(v, newv, update_index)) + for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) + ] else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + def insert(cache, scaler, new_entry, update_index): + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) + if self.env.generate_cache_stacked: + vals = jnp.expand_dims(vals, 0) + scales = jnp.expand_dims(scales, 0) new_scaler = jax.lax.dynamic_update_slice( scaler, scales, - [slot, 0, pos, 0], + update_index, ) new_scaler = jax.lax.with_sharding_constraint( new_scaler, self.replicated @@ -370,19 +391,37 @@ def insert(cache, scaler, new_entry): res = jax.lax.dynamic_update_slice( cache, vals, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res, new_scaler - for (k, v), (kscaler, vscaler), (newk, newv) in zip( - decode_state.caches, decode_state.cache_scales, prefix.caches - ): - kcache, kscale = insert(k, kscaler, newk) - vcache, vscale = insert(v, vscaler, newv) - caches.append((kcache, vcache)) - scales.append((kscale, vscale)) - + if self.env.generate_cache_stacked: + cache_k, k_scale = ( + decode_state.caches[0][0], + decode_state.cache_scales[0][0], + ) + cache_v, v_scale = ( + decode_state.caches[0][1], + decode_state.cache_scales[0][1], + ) + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + # newk = jnp.expand_dims(newk, 0) + # newv = jnp.expand_dims(newv, 0) + cache_k, k_scale = insert(cache_k, k_scale, newk, update_index) + cache_v, v_scale = insert(cache_v, v_scale, newv, update_index) + caches = [(cache_k, cache_v)] + scales = [(k_scale, v_scale)] + else: + update_index = [slot, 0, pos, 0] + for (k, v), (kscaler, vscaler), (newk, newv) in zip( + decode_state.caches, decode_state.cache_scales, prefix.caches + ): + kcache, kscale = insert(k, kscaler, newk, update_index) + vcache, vscale = insert(v, vscaler, newv, update_index) + caches.append((kcache, vcache)) + scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) return DecodeState( tokens, @@ -416,10 +455,10 @@ def _insert_wrap( cond = jax.lax.cond( decode_state.current_position > start_insert, lambda x, start_insert, current_position: jnp.logical_and( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), lambda x, start_insert, current_position: jnp.logical_or( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), x, start_insert, @@ -494,11 +533,6 @@ def insert( decode_state: DecodeState, slot: int, ) -> DecodeState: - # logging.info( - # 'Jet input prefix: %s, decode state before insert: %s', - # prefix, - # decode_state, - # ) if self.env.ring_buffer: start_insert = decode_state.current_position - prefix.seq_len end_insert = start_insert + prefix.caches[0][0].shape[2] # padded seclen @@ -580,11 +614,9 @@ def generate( pos = decode_state.current_position if self.env.ring_buffer: input_indexes = jnp.full((1,), pos) - mask = decode_state.mask.at[:, decode_state.current_position].set(0) else: input_indexes = decode_state.input_pos - batch = jnp.arange(self.env.batch_size) - mask = 
decode_state.mask.at[batch, decode_state.input_pos].set(0) + ragged_batch_index, ragged_block_index = ( self.precompute_ragged_block_indices(decode_state) ) @@ -592,6 +624,16 @@ def generate( (-1) ), ragged_block_index.reshape((-1)) + def update_mask(): + if self.env.ring_buffer: + return decode_state.mask.at[:, decode_state.current_position].set(0) + + batch = jnp.arange(self.env.batch_size) + return decode_state.mask.at[batch, decode_state.input_pos].set(0) + + mask = decode_state.mask + if not self.env.lazy_cache_update: + mask = update_mask() logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, @@ -605,6 +647,10 @@ def generate( ragged_block_index, ) + if self.env.lazy_cache_update: + # fill mask later, now use flash attention + mask = update_mask() + next_token = self._sampling(logits, self.env.batch_size) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -648,11 +694,6 @@ def generate( input_pos, mask, ) - print( - "new_pos", - (decode_state.current_position + 1) % self.env.cache_sequence_length, - ) - print(f"new_token: {jnp.squeeze(next_token)}") return new_decode_state, result_tokens # pylint: disable-next=all @@ -832,6 +873,10 @@ def create_pytorch_engine( nucleus_topp=None, topk=None, ring_buffer=True, + flash_attention=False, + generate_cache_stacked=False, + new_cache_stacked=False, + lazy_cache_update=False, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -902,6 +947,10 @@ def create_pytorch_engine( nucleus_topp=nucleus_topp, topk=topk, ring_buffer=ring_buffer, + flash_attention=flash_attention, + generate_cache_stacked=generate_cache_stacked, + new_cache_stacked=new_cache_stacked, + lazy_cache_update=lazy_cache_update, ) if shard_on_batch and sharding_config: diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index fb1b99ba..84289d90 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -16,6 +16,7 @@ from typing import Tuple import jax +import jax.numpy as jnp import jax.sharding as jsharding from jax.experimental import mesh_utils import torch_xla2 @@ -91,11 +92,18 @@ class JetEngineEnvironmentData: block_size: int = 512 # Starting position - starting_position: int = 512 + starting_position: int = 0 # Ring buffer ring_buffer: bool = True + flash_attention: bool = False + + generate_cache_stacked: bool = False + + new_cache_stacked: bool = False + + lazy_cache_update: bool = False # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -109,6 +117,10 @@ class JetEngineEnvironmentData: # temperature parameter for scaling probability temperature: float = 1.0 + testing: bool = False + + testing_seed: int = 0 + # pylint: disable-next=all class JetEngineEnvironment: @@ -119,10 +131,34 @@ def __init__(self, data: JetEngineEnvironmentData): self.batch_size = self._data.batch_size self.seq_len = self._data.max_input_sequence_length self.cache_len = self._data.cache_sequence_length - self.ragged_mha = self._data.ragged_mha self.block_size = self._data.block_size self.starting_position = self._data.starting_position + self.num_layers = self._data.num_layers + self.testing = self._data.testing + self.testing_seed = self._data.testing_seed self.ring_buffer = self._data.ring_buffer + + if not self.ring_buffer: + self.lazy_cache_update = True + self.ragged_mha = True + self.flash_attention = True + self.generate_cache_stacked = True + self.new_cache_stacked = True + + if self.testing: + 
self.lazy_cache_update = self._data.lazy_cache_update + self.ragged_mha = self._data.ragged_mha + self.flash_attention = self._data.flash_attention + self.generate_cache_stacked = self._data.generate_cache_stacked + self.new_cache_stacked = self._data.new_cache_stacked + + self.default_type = jnp.bfloat16 if self._data.bf16_enable else jnp.float32 + + if self.generate_cache_stacked: + self.cache_shape = (self.num_layers, *self._data.cache_shape) + else: + self.cache_shape = self._data.cache_shape + P = jax.sharding.PartitionSpec num_of_partitions = jax.device_count() @@ -136,19 +172,29 @@ def __init__(self, data: JetEngineEnvironmentData): self.x_sharding = jsharding.NamedSharding(self.mesh, P("x")) self.replicated = jsharding.NamedSharding(self.mesh, P()) + if self.generate_cache_stacked: + self.attention_kv_axis_names = ( + "layer", + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", + ) if data.shard_on_batch: - cache_sharding_axis = 0 + self.kv_cache_shard_axis = "batch" else: - cache_sharding_axis = self.attention_kv_axis_names.index( - self.kv_cache_shard_axis - ) + self.kv_cache_shard_axis = "num_attn_heads" - if self.cache_shape[cache_sharding_axis] == 1: + self.cache_sharding_axis = self.attention_kv_axis_names.index( + self.kv_cache_shard_axis + ) + + if self.cache_shape[self.cache_sharding_axis] == 1: # cannot shard on an axis that is 1 # default to last - cache_sharding_axis = len(self.cache_shape) - 1 + self.cache_sharding_axis = len(self.cache_shape) - 1 - self.cache_sharding = self.sharding_by_axis(cache_sharding_axis) + self.cache_sharding = self.sharding_by_axis(self.cache_sharding_axis) self._load_sharding_config() def _load_sharding_config(self): @@ -194,19 +240,20 @@ def make_caches_prefill(self): def make_caches_generate(self): """Create kv caches for inference generation""" caches = [] - shape = self._data.cache_shape - for _ in range(self.num_layers): + layered_cache_count = 1 if self.generate_cache_stacked else self.num_layers + + for _ in range(layered_cache_count): if self._data.quant_config.enable_kv_quantization: caches.append( cache_manager.Int8KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + self.cache_shape, self.cache_sharding, env=self ) ) else: caches.append( cache_manager.KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + self.cache_shape, self.cache_sharding, env=self ) ) return caches diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index e2756d73..d484df93 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -368,9 +368,19 @@ def apply_rotary_emb( def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" - bs, n_kv_heads, slen, head_dim = x.shape + *_, bs, n_kv_heads, slen, head_dim = x.shape + stacked = x.ndim == 5 + if n_rep == 1: return x + + if stacked: + layer = x.shape[0] + return ( + x[:, :, :, None, :, :] + .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim) + ) return ( x[:, :, None, :, :] .expand(bs, n_kv_heads, n_rep, slen, head_dim) @@ -380,18 +390,36 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class AttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = 
( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.flash_attention = ak.flash_attention + self.ragged_attention_orig = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -414,53 +442,142 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - with jax.named_scope("attn_insert_cache"): - keys, values = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + def attend(xq, keys, values, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1: + true_len = 2 + # xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) - with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( - self.ragged_attention, + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( + impl, xq, keys, values, + self.layer_id, start, end, ragged_batch_index, ragged_block_index, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention( + xq, keys, values, self.layer_id, mask=local_mask + ) else: - output = self.dense_attention(xq, keys, values, None, None, mask) + local_output = self.dense_attention( + xq, keys, values, None, None, local_mask + ) + local_max = None + local_denom = None + + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) + + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + if local_denom is not None: + local_denom = local_denom[:, :, 0:seqlen, :] + + # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + # if local_max is not None and local_denom is not None: + # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + self.env.apply_sharding(local_output, axis=self.q_shard_axis) + 
return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + orig_keys, orig_values = cache.update(xk, xv, self.layer_id) + # We are not using ragged attention for prefill yet. + if not self.env.ragged_mha or seqlen > 1: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) + + # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, mask + ) + # Updating cache during each step still has very large impact on latency. + # For non flash attention or prefill, existing output contains everything + if not self.env.lazy_cache_update or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + if not self.env.ragged_mha or seqlen > 1: + xk = repeat_kv(xk, n_rep) + xv = repeat_kv(xv, n_rep) + new_output, (new_max, new_denom) = attend(xq, xk, xv, None) + + with jax.named_scope("attn_global"): + # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") + + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output - if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] - # For XLA matmul performance boost - # output = torch.matmul(scores, values) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + return attn_out class Int8KVAttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = ( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.flash_attention = ak.flash_attention + self.ragged_attention_orig = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -483,24 +600,32 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = 
xk.shape[-1] n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - - with jax.named_scope("attn_insert_cache"): - keys, values, k_scaler, v_scaler = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) - with jax.named_scope("attn_qkv"): + # We are not using ragged attention for prefill yet. if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( - self.ragged_attention, + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( + impl, xq, keys, values, + self.layer_id, start, end, ragged_batch_index, @@ -508,22 +633,94 @@ def __call__( k_scaler, v_scaler, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention( + xq, + keys, + values, + self.layer_id, + k_scaler, + v_scaler, + mask=local_mask, + ) else: - output = self.dense_attention( - xq, keys, values, k_scaler, v_scaler, mask + local_output = self.dense_attention( + xq, keys, values, k_scaler, v_scaler, local_mask ) + local_max = None + local_denom = None - if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + local_denom = local_denom[:, :, 0:seqlen, :] + + self.env.apply_sharding(local_output, axis=self.q_shard_axis) + return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + ( + orig_keys, + orig_values, + new_key, + new_value, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + ) = cache.update(xk, xv, self.layer_id) + # We are not using ragged attention for prefill yet. 
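+      # repeat_kv expands the KV heads to match the query heads for the dense
+      # and flash attention paths (the ragged kernel handles grouped queries itself).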
+ if not self.env.ragged_mha or seqlen > 1: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, k_scaler, v_scaler, mask + ) + + # For non flash attention or prefill, existing output contains everything + if not self.env.lazy_cache_update or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + # At this point, flash attention or ragged attention must have been enabled + if not self.env.ragged_mha or seqlen > 1: + new_key = repeat_kv(new_key, n_rep) + new_value = repeat_kv(new_value, n_rep) + new_output, (new_max, new_denom) = attend( + xq, new_key, new_value, new_k_scaler, new_v_scaler, None + ) + + with jax.named_scope("attn_global"): + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output + + return attn_out class Attention(ModuleBase): """Attention module.""" - def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): + def __init__( + self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads @@ -531,6 +728,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): self.n_rep = self.n_heads // self.n_kv_heads self.env = env self.hidden_size = hidden_size + self.layer_id = layer_id LinearLayer = get_quantized_linear_layer(env.quant_config) linear_kwargs = {} @@ -550,7 +748,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): if env.quant_config.enable_kv_quantization else AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, self.layer_id) self.q_size = n_heads * self.head_dim self.kv_size = self.n_kv_heads * self.head_dim @@ -630,16 +828,26 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) + if mask.ndim == 2: + if seqlen == 1: + mask = mask[:, None, None, :] + else: + mask = mask[None, None, :, :] + + # if cache is not None and cache.cache_k is not None: + # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, xk, xv, mask, + # cache[self.layer_id], cache, start, end, ragged_batch_index, ragged_block_index, ).type_as(xq) + # print(f"output {output.shape}") output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 01b647db..1a4f15e7 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -466,6 +466,9 @@ def prefill_ray( logits = logits[0] token = np.argmax(logits[true_length - 1]) + updated_caches = multihost_utils.process_allgather( + updated_caches, tiled=True + ) prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False) @@ -596,7 +599,7 @@ def insert(cache, new_entry): @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) diff --git 
a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 1072dad9..5773b8bd 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -73,6 +73,7 @@ def __init__( head_dim: int, device, env, + layer_id, ): super().__init__() @@ -135,7 +136,7 @@ def __init__( if env.quant_config.enable_kv_quantization else layers.AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, layer_id) def forward( self, @@ -272,7 +273,7 @@ def forward(self, x): class GemmaDecoderLayer(nn.Module): - def __init__(self, config: gemma_config.GemmaConfig, env): + def __init__(self, config: gemma_config.GemmaConfig, env, layer_id): super().__init__() self.self_attn = GemmaAttention( config.hidden_size, @@ -281,6 +282,7 @@ def __init__(self, config: gemma_config.GemmaConfig, env): config.head_dim, config.device, env, + layer_id, ) self.mlp = GemmaMLP( @@ -340,8 +342,8 @@ def __init__(self, config: gemma_config.GemmaConfig, env): self.env = env self.layers = nn.ModuleList() - for _ in range(config.num_hidden_layers): - self.layers.append(GemmaDecoderLayer(config, env)) + for layer_id in range(config.num_hidden_layers): + self.layers.append(GemmaDecoderLayer(config, env, layer_id)) self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, device=config.device ) diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index 7956667d..1b72c0a7 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -45,7 +45,8 @@ def get_arg( "dim": 128, "vocab_size": 32000, "multiple_of": 32, - "n_heads": 8, + "n_heads": 64, + "n_kv_heads": 8, "n_layers": 3, "norm_eps": 1e-05, } diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 0752cc45..0a986f2a 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -103,6 +103,7 @@ def __init__( args.dim, env=env, device=args.device, + layer_id=layer_id, ) self.feed_forward = FeedForward( dim=args.dim, @@ -249,7 +250,6 @@ def forward( ragged_batch_index: precomputed batch index for ragged attention ragged_block_index: precomputed block index for ragged attention """ - with jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] h = self.tok_embeddings(tokens) @@ -259,12 +259,16 @@ def forward( freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - assert len(caches) == len( - self.layers - ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" end = None if start is None else (start + input_pos) % self.env.cache_len - for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)): + # For stacked case, cannot get cache inside the loop which will cause cache copy + for layer_id, layer in enumerate(self.layers): + if caches[0].stacked: + cache = caches[0] + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): h = layer( h, freqs_cis, diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index 7d053703..89a66378 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -38,7 +38,8 @@ def __init__(self, config: ModelArgs, env) -> 
None: config.vocab_size, config.dim, device=config.device ) self.layers = nn.ModuleList( - TransformerBlock(config, env) for _ in range(config.n_layer) + TransformerBlock(config, env, layer_id) + for layer_id in range(config.n_layer) ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) LinearLayer = get_quantized_linear_layer(env.quant_config) @@ -142,7 +143,7 @@ def get_weight_sharding_type(): class TransformerBlock(nn.Module): - def __init__(self, config: ModelArgs, env) -> None: + def __init__(self, config: ModelArgs, env, layer_id) -> None: super().__init__() self.attention = Attention( config.n_head, @@ -151,6 +152,7 @@ def __init__(self, config: ModelArgs, env) -> None: config.dim, env=env, device=config.device, + layer_id=layer_id, ) self.block_sparse_moe = MOEFeedForward(config, config.device, env) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) diff --git a/run_interactive.py b/run_interactive.py index eef2def8..8463658c 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -18,13 +18,10 @@ from typing import List # import torch_xla2 first! -import torch_xla2 # pylint: disable import jax import numpy as np -from absl import app, flags -from colorama import Fore, Style +from absl import app from jetstream.engine import token_utils -from jetstream_pt import engine as je from jetstream_pt.config import FLAGS, create_engine_from_config_flags @@ -42,16 +39,30 @@ def main(argv): max_output_length = 1024 profiling_output = FLAGS.profiling_output - profiling_prefill = FLAGS.profiling_prefill - if profiling_output and profiling_prefill: + profiling_prefill = ( + FLAGS.profiling_prefill + and profiling_output is not None + and profiling_output != "" + ) + + if profiling_prefill: jax.profiler.start_trace(profiling_output) decode_state = engine.init_decode_state() + + if profiling_prefill: + jax.profiler.stop_trace() + prompts: List[str] = [ + # pylint: disable-next=all "I believe the meaning of life is", + # pylint: disable-next=all "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", + # pylint: disable-next=all "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", + # pylint: disable-next=all "[INST] <>\nYou are an AI assistant that helps people find information. 
Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", + # pylint: disable-next=all "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]", ] for prompt in prompts: @@ -62,21 +73,31 @@ def main(argv): print(f"---- Encoded tokens are: {tokens}") # pylint: disable-next=all + if profiling_prefill: + jax.profiler.start_trace(profiling_output) + prefill_result, _ = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) # pylint: disable-next=all decode_state = engine.insert(prefill_result, decode_state, slot=slot) + + if profiling_prefill: + jax.profiler.stop_trace() + sampled_tokens_list = [] print(f"---- Streaming decode started on #slot{slot}.") complete = np.zeros((1,), dtype=np.bool_) while True: - if profiling_output and not profiling_prefill: + if profiling_output: jax.profiler.start_trace(profiling_output) + decode_state, result_tokens = engine.generate(params, decode_state) - if profiling_output and not profiling_prefill: - jax.profiler.stop_trace() result_tokens = result_tokens.convert_to_numpy() + + if profiling_output: + jax.profiler.stop_trace() + output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, slot=slot, @@ -94,9 +115,6 @@ def main(argv): print("---- All output text.") print(tokenizer.decode(sampled_tokens_list)) - if profiling_output and profiling_prefill: - jax.profiler.stop_trace() - if __name__ == "__main__": os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py index 0d11796e..b6ffb43c 100644 --- a/run_interactive_disaggregated.py +++ b/run_interactive_disaggregated.py @@ -19,9 +19,7 @@ from typing import List from absl import app from absl import flags -from colorama import Fore, Style -import numpy as np import jax from jetstream.engine import token_utils @@ -129,7 +127,6 @@ def main(argv): print("Load params ", time.perf_counter() - start) metadata = prefill_engine.get_tokenizer() - tokenizer = prefill_engine.build_tokenizer(metadata) vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) stop_tokens = [vocab.eos_id, vocab.pad_id] max_output_length = 1024 @@ -157,19 +154,21 @@ def main(argv): print(f"---- Input prompts are: {prompt}") print(f"---- Encoded tokens are: {tokens}") - # pylint: disable-next=all print( + # pylint: disable-next=all f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}" ) prefill_result, _ = prefill_engine.prefill( params=None, padded_tokens=tokens, true_length=true_length ) print( + # pylint: disable-next=all f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}" ) decode_engine.transfer(prefill_result) - # pylint: disable-next=all + print( + # pylint: disable-next=all f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}" ) decode_state = decode_engine.insert(prefill_result, None, slot=slot) diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 24b27987..f9307126 100644 --- 
a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -19,7 +19,6 @@ import jax from absl import app, flags -from colorama import Fore, Style from jetstream.engine import token_utils from jetstream_pt import ray_engine from jetstream_pt.config import FLAGS @@ -57,7 +56,7 @@ def create_engine(): sharding_config=FLAGS.sharding_config, num_hosts=_NUM_HOSTS.value, worker_chips=_WORKER_CHIPS.value, - tpu_chips=_TPU_CHIPS, + tpu_chips=_TPU_CHIPS.value, ) print("Initialize engine", time.perf_counter() - start) diff --git a/run_ray_serve_interleave.py b/run_ray_serve_interleave.py index 6d4edb5d..853ce068 100644 --- a/run_ray_serve_interleave.py +++ b/run_ray_serve_interleave.py @@ -40,7 +40,11 @@ def create_head_resource_name(generation, tpu_chips): - return f"TPU-{generation}-{tpu_chips}-head" + if generation == "v5litepod": + return f"TPU-{generation}-{tpu_chips}-head" + else: + tpu_cores = tpu_chips * 2 + return f"TPU-{generation}-{tpu_cores}-head" def create_engine(**kwargs): @@ -73,6 +77,7 @@ def create_engine(**kwargs): @serve.deployment class JetStreamDeployment: + """JetStream deployment.""" def __init__(self, **kwargs): os.environ["XLA_FLAGS"] = ( @@ -111,18 +116,24 @@ def __init__(self, **kwargs): print("Started jetstream driver....") + # pylint: disable-next=all async def Decode( - self, request: jetstream_pb2.DecodeRequest + self, + # pylint: disable-next=all + request: jetstream_pb2.DecodeRequest, + # pylint: disable-next=all ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: - + """Async decode function.""" return self.orchestrator.Decode(request) def main(_argv): + """Main function""" resource_name = create_head_resource_name( FLAGS.tpu_generation, FLAGS.tpu_chips ) print(f"Using head resource {resource_name}") + # pylint: disable-next=all deployment = JetStreamDeployment.options( ray_actor_options={"resources": {resource_name: 1}} ).bind( diff --git a/run_server.py b/run_server.py index be5933ec..6f9fbed8 100644 --- a/run_server.py +++ b/run_server.py @@ -17,7 +17,6 @@ from typing import Sequence # import torch_xla2 first! -import torch_xla2 # pylint: disable import jax from absl import app, flags from jetstream.core import server_lib diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 03489e1a..97592804 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -19,7 +19,6 @@ from absl import app, flags # import torch_xla2 first! 
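Earlier in this patch, `create_head_resource_name` in `run_ray_serve_interleave.py` starts naming the Ray head resource by chip count for v5litepod and by core count (two cores per chip) for other TPU generations. A quick illustration of the resulting names, condensed from the function as added above:

```python
def create_head_resource_name(generation, tpu_chips):
    # v5litepod head resources are keyed by chip count; other generations by core count.
    if generation == "v5litepod":
        return f"TPU-{generation}-{tpu_chips}-head"
    tpu_cores = tpu_chips * 2
    return f"TPU-{generation}-{tpu_cores}-head"

print(create_head_resource_name("v5litepod", 4))  # -> TPU-v5litepod-4-head
print(create_head_resource_name("v4", 4))         # -> TPU-v4-8-head
```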
-import torch_xla2 # pylint: disable import jax from jetstream.core import server_lib from jetstream.core.config_lib import ServerConfig diff --git a/tests/helpers.py b/tests/helpers.py index 00442517..3c5cb4ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,8 @@ from jetstream_pt import environment -def make_env_tiny(bf16_enable=True): +# pylint: disable-next=all +def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) @@ -26,11 +27,14 @@ def make_env_tiny(bf16_enable=True): environment_data.cache_sequence_length, config.dim // config.n_heads, ) + environment_data.testing = True + env_data_update_fn(environment_data) env = environment.JetEngineEnvironment(environment_data) env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu return env, config +# pylint: disable-next=all def make_mixtral_env(bf16_enable=True): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) @@ -55,14 +59,16 @@ def make_mixtral_env(bf16_enable=True): return env, config +# pylint: disable-next=all def to_xla_tensor(tree): return torch_xla2.default_env().to_xla(tree) +# pylint: disable-next=all def call_xla_model(model, weights, args): with jax.default_device(jax.devices("cpu")[0]): xla_weights, xla_inputs = to_xla_tensor((weights, args)) with torch_xla2.default_env(): result = torch.func.functional_call(model, xla_weights, xla_inputs) - result_torch = torch_xla2.tensor.j2t(result._elem) + result_torch = torch_xla2.tensor.j2t(result.jax()) return result_torch diff --git a/tests/test_hf_names.py b/tests/test_hf_names.py index c2230cde..83b76425 100644 --- a/tests/test_hf_names.py +++ b/tests/test_hf_names.py @@ -4,10 +4,13 @@ class TestModuleBase(unittest.TestCase): + """Test module base.""" def test_get_hf_names_to_real_name(self): + """Test get hugginface names to real name.""" class MyModule(ModuleBase): + """My module.""" def __init__(self): super().__init__() @@ -18,6 +21,9 @@ def __init__(self): self.param = torch.nn.Parameter(torch.randn(10)) self.hf_name("param", "model.param") + def forward(self): + """Forward function.""" + module = MyModule() expected_mapping = { "model.my_linear1.weight": "linear1.weight", @@ -30,7 +36,10 @@ def __init__(self): self.assertEqual(module.get_hf_names_to_real_name(), expected_mapping) def test_get_sharding_annotations(self): + """Test get sharding annotations.""" + class MyModule(ModuleBase): + """MyModule.""" def __init__(self): super().__init__() @@ -38,12 +47,19 @@ def __init__(self): self.embedding = torch.nn.Embedding(100, 50) self.inner = InnerModule() + def forward(self): + """Forward function.""" + class InnerModule(ModuleBase): + """Inner modeule.""" def __init__(self): super().__init__() self.fc = torch.nn.Linear(50, 100) + def forward(self): + """Forward function.""" + module = MyModule() module.annotate_sharding("linear.weight", 0) module.annotate_sharding("embedding.weight", 1) diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index dcbcf5f2..6ea6dd1b 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -22,7 +22,7 @@ import torch import torch_xla2 from torch.utils import _pytree as pytree - +from absl.testing import parameterized from jetstream_pt.engine import PyTorchEngine from jetstream_pt.third_party.llama import model_exportable, model_args @@ -31,7 +31,7 @@ from tests import helpers -class 
LlamaE2ETest(unittest.TestCase): +class LlamaE2ETest(parameterized.TestCase): """This test class includes all E2E test for llama2""" def _from_torch(self, tree): @@ -42,6 +42,7 @@ def _make_env(self, bf16_enable=True): torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) jax.config.update("jax_traceback_filtering", "off") + # pylint: disable-next=all config = model_args.get_model_args("tiny", 128, 1, 32000, True) environment_data = environment.JetEngineEnvironmentData() environment_data.max_input_sequence_length = 128 @@ -187,6 +188,9 @@ def _llama_e2e(self, env, model_arg): model_ours = model_exportable.Transformer(model_arg, env) + for k, v in model_ours.state_dict().items(): + if "scale" in k: + state_dict[k] = helpers.to_xla_tensor(v) engine = PyTorchEngine(pt_model=model_ours, env=env) params = self._from_torch(state_dict) @@ -233,6 +237,58 @@ def test_llama_e2e_bfloat16(self): out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertNotEqual(out_tokens, expected_output_tokens) + @parameterized.named_parameters( + ("ring_buffer_f32", True, False, False), + ("left_aligned_f32", False, False, False), + ) + def test_llama_e2e_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + @parameterized.named_parameters( + ("ring_buffer_int8", True, True, True), + ("ring_buffer_bf16", True, False, True), + ("left_aligned_int8", False, True, True), + ("left_aligned_bf16", False, False, True), + ) + def test_llama_e2e_no_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertNotEqual(out_tokens, expected_output_tokens) + # pylint: disable-next=all def test_llama_e2e_two_addtional_tokens(self): """end to end jetstream llama with addtional tokens""" diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 4d4ddfd6..efbaa09b 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -17,7 +17,6 @@ import jax.numpy as jnp import torch import torch_xla2 -from . 
import helpers from jetstream_pt.third_party.llama import model_exportable from jetstream_pt.third_party.llama import model_original @@ -30,6 +29,8 @@ from jetstream_pt import layers from jetstream_pt import cache_manager +from . import helpers + class ModelComponentTest(unittest.TestCase): """Test diff between original model and xla model for transformer, @@ -65,15 +66,19 @@ def _make_freqs_cis(self, model_arg, seqlen, start_pos): freqs_cis = freqs_cis[start_pos : start_pos + seqlen] return freqs_cis - def _generate_mask(self, cache_length, pos, seqlen): + def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True): x = jnp.arange(0, cache_length) - cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + if ring_buffer: + cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + else: + # Left aligned buffer we postpone the cache update + cond = jnp.logical_and(x < pos, x >= pos - seqlen) res = jnp.where(cond, 0, float("-inf")) return torchjax.to_torch(res) def _compare_cache(self, cache_torch, cache_jax): _, seq, _, _ = cache_torch.shape - cache_j = torch_xla2.tensor.j2t(cache_jax._elem) + cache_j = torch_xla2.tensor.j2t(cache_jax.jax()) for s in range(seq): print("diff ", (cache_torch[0, s] - cache_j[0, :, s]).norm()) @@ -91,6 +96,7 @@ def _make_one_cache_for_generate(self, env, pos): # pylint: disable-next=all def test_attention(self): + torch.manual_seed(0) env, model_arg = helpers.make_env_tiny(False) attention_orig = model_original.Attention(model_arg) @@ -101,6 +107,7 @@ def test_attention(self): hidden_size=model_arg.dim, device="cpu", env=env, + layer_id=0, ) seqlen = 32 @@ -135,13 +142,14 @@ def test_attention(self): cache_decode = self._make_one_cache_for_generate(env, pos) # insert prefilled cache entry - cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : - ].set(cache.cache_k._elem) - - cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : - ].set(cache.cache_v._elem) + # pylint: disable-next=all + cache_decode.cache_k._elem = ( + cache_decode.cache_k.jax().at[..., :pos, :].set(cache.cache_k.jax()) + ) + # pylint: disable-next=all + cache_decode.cache_v._elem = ( + cache_decode.cache_v.jax().at[..., :pos, :].set(cache.cache_v.jax()) + ) # self._compare_cache(attention_orig.cache_k, cache_decode.cache_k) # Now do one with decode @@ -154,7 +162,7 @@ def test_attention(self): None, # mask is none for decode ) expected_out = attention_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) @@ -170,6 +178,7 @@ def test_attention(self): self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) def test_gemma_attention(self): + """Test gemma attention.""" with jax.default_matmul_precision("float32"): env, model_arg = helpers.make_env_tiny(False) @@ -203,6 +212,7 @@ def init_weights(model): head_dim=head_dim, device="meta", env=env, + layer_id=0, ) def load_hook(state_dict, prefix, *args): @@ -228,8 +238,8 @@ def load_hook(state_dict, prefix, *args): freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) mask = self._prefill_mask(seqlen, start_pos) kv_write_indexes = torch.arange(0, seqlen) - cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) - cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) + cache_k = torch.zeros((batch, seqlen, num_kv_heads, 
head_dim)) + cache_v = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) expected_out = attention_orig(*inputs_orig) @@ -299,12 +309,14 @@ def test_transformer_block(self): cache_decode = self._make_one_cache_for_generate(env, pos) # insert prefilled cache entry - cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : - ].set(cache.cache_k._elem) - cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : - ].set(cache.cache_v._elem) + # pylint: disable-next=all + cache_decode.cache_k._elem = ( + cache_decode.cache_k.jax().at[..., :pos, :].set(cache.cache_k.jax()) + ) + # pylint: disable-next=all + cache_decode.cache_v._elem = ( + cache_decode.cache_v.jax().at[..., :pos, :].set(cache.cache_v.jax()) + ) # Now do one with decode x2 = torch.randn((1, 1, model_arg.dim)) @@ -316,7 +328,7 @@ def test_transformer_block(self): None, # mask is none for decode ) expected_out = block_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) @@ -426,14 +438,16 @@ def test_mixtral_transformer(self): self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) def test_mixtral_moe(self): + """Test mixtral moe module.""" config = mixtral_config.ModelArgs() config.intermediate_size = 16 config.dim = 16 m = mixtral.ConditionalFeedForward(config) # random init states = m.state_dict() - for k, v in states.items(): - states[k].normal_() + for _, v in states.items(): + # pylint: disable-next=all + v.normal_() m.load_state_dict(states, assign=True) seqlen = 3 diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 190e6fda..d150c67b 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
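The quantization tests that follow call `quantize_tensor` with negative reduce axes `(-3, -1)`, matching the `reduce_axis` change in `ray_worker.py` above, so the same call works for both the 4-D per-layer KV cache and the 5-D stacked cache. Below is a minimal sketch of symmetric int8 quantization over chosen axes; it is an illustration only, not the repo's `quantize.quantize_tensor` (which also returns a zero point):

```python
import torch

def quantize_tensor_int8(x, reduce_axis):
    """Symmetric int8 quantization; one scale per slice left over after reduction.

    With negative axes, (-3, -1) always reduces over the heads and head_dim axes,
    whether the cache is (bs, heads, seq, dim) or stacked as (layers, bs, heads, seq, dim).
    """
    amax = x.abs().amax(dim=reduce_axis, keepdim=True)
    scale = amax.clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize_tensor_int8(q, scale):
    return q.to(torch.float32) * scale

# Round-trip on a fake stacked cache: (layers, bs, heads, seq, dim).
cache = torch.randn(2, 4, 2, 100, 2)
q, scale = quantize_tensor_int8(cache, reduce_axis=(-3, -1))
assert torch.allclose(dequantize_tensor_int8(q, scale), cache, atol=scale.max().item())
```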
-import copy import functools import unittest import jax import jax.numpy as jnp -import jax.sharding as jsharding import torch import torch_xla2 -from jax.experimental import mesh_utils -from jetstream_pt import cache_manager, layers, quantize, torchjax +from absl.testing import parameterized +from tests import helpers + + +from jetstream_pt import cache_manager, layers, torchjax, environment from jetstream_pt.environment import QuantizationConfig from jetstream_pt.layers import ( WeightOnlyBlockwiseQuantizedLinear, @@ -30,14 +31,12 @@ ) from jetstream_pt.quantize_model import quantize_model from jetstream_pt.quantize import dequantize_tensor, quantize_tensor -from tests import helpers -from torch.utils import _pytree as pytree -from torch_xla2 import tensor + torch.manual_seed(12345) -class QuantizationTest(unittest.TestCase): +class QuantizationTest(parameterized.TestCase): """test kv cache quantization""" def _xla_tensor(self, shape): @@ -70,74 +69,222 @@ def _print_diff(self, w, w_dq): print(" norm: ", (w - w_dq).norm()) print(" cosine dist: ", self._calc_cosine_dist(w, w_dq)) - def test_kv_cache(self): + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + def test_kv_cache(self, ring_buffer): """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + + env, _ = helpers.make_env_tiny(True, update_env_data) + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # layer, bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny() - cache = cache_manager.Int8KVCacheGenerate.empty( - cache_shape, None, False, env - ) + + cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, env) # seqlen is 1 - k = self._xla_tensor((3, 2, 1, 2)) - v = self._xla_tensor((3, 2, 1, 2)) + k = self._xla_tensor((batch, 2, 1, 2)) + v = self._xla_tensor((batch, 2, 1, 2)) + + def update_finalize_compare(in_k, in_v, in_layer, in_pos): + cache.input_pos = ( + [in_pos] if env.ring_buffer else jnp.array([in_pos] * batch) + ) + + # layer id may or may not take effect, depends on the env config. 
+ cache.update(in_k, in_v, layer_id=in_layer) + cache.finalize() + if env.quant_config.enable_kv_quantization: + new_k = cache.cache_k * cache.k_scaler + new_v = cache.cache_v * cache.v_scaler + else: + new_k = cache.cache_k + new_v = cache.cache_v + + if env.generate_cache_stacked: + self.assertTrue( + jnp.allclose( + k.jax(), + new_k.jax()[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) + ) + self.assertTrue( + jnp.allclose( + v.jax(), + new_v.jax()[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) + ) + else: + self.assertTrue( + jnp.allclose( + k.jax(), new_k.jax()[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) + ) + self.assertTrue( + jnp.allclose( + v.jax(), new_v.jax()[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) + ) + + update_finalize_compare(k, v, in_layer=1, in_pos=57) + update_finalize_compare(k, v, in_layer=1, in_pos=58) + update_finalize_compare(k, v, in_layer=2, in_pos=3) + + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + # pylint: disable-next=all + def test_kv_kernel(self, ring_buffer): + """test kv cache quantization""" - cache.input_pos = [57] - new_k, new_v, scaler_k, scaler_v = cache.update(k, v) - new_k = new_k * scaler_k - new_v = new_v * scaler_v + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + + env, _ = helpers.make_env_tiny(False, update_env_data) + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # layers, bs, num heads, seqlen, dim - self.assertTrue( - jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1) - ) - self.assertTrue( - jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1) - ) - - def test_kv_kernel(self): - """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny(False) + key = jax.random.PRNGKey(123) key2 = jax.random.PRNGKey(456) - cache_k_jax = jax.random.normal(key, cache_shape) - cache_v_jax = jax.random.normal(key2, cache_shape) + cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_type) + cache_v_jax = jax.random.normal(key2, cache_shape, dtype=env.default_type) - cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) + start = jnp.zeros((batch,), dtype=jnp.int32) - cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env) + cache_k, cache_v, start = torchjax.to_torch( + (cache_k_jax, cache_v_jax, start) + ) + + # Prepare quantized cache before written in + cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (-3, -1)) + cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) # 1 is seqlen - xq = jax.random.normal(key, (3, 2, 1, 2)) - xk = jax.random.normal(key, (3, 2, 1, 2)) - xv = jax.random.normal(key, (3, 2, 1, 2)) + xq = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) + xk = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) + xv = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) xq, xk, xv = torchjax.to_torch((xq, xk, xv)) - attention_float = layers.AttentionKernel(env) - 
float_res = attention_float(xq, xk, xv, None, cache) + def get_var(position: int): + pos = ( + [position] + if env.ring_buffer + else jnp.array([position] * batch, dtype=jnp.int64) + ) + mask = jax.lax.broadcast_in_dim( + jnp.array([0] * position + [float("-inf")] * (100 - position)), + (env.batch_size, 1, 1, 100), + (3,), + ) + mask = torchjax.to_torch((mask)) + return pos, mask + + cache = cache_manager.KVCacheGenerate(cache_k, cache_v, None, None, env) + # layer_id doesn't matter, will assign later + attention_float = layers.AttentionKernel(env, layer_id=0) + + float_res = [] + + def update_finalize_record( + in_attention, in_cache, in_q, in_k, in_v, in_layer, in_pos + ): + pos, mask = get_var(in_pos) + in_attention.layer_id = in_layer + in_cache.input_pos = pos + ret = in_attention( + in_q, in_k, in_v, mask, in_cache, start=start, end=pos + ) + in_cache.finalize() + return ret + + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 1, 57) + ) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 1, 58) + ) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 2, 3) + ) - # == + # Running into the issue of multiple env object always share the same quant_config. + # Record the results and compare as a workaround. + # pylint: disable-next=all + env._data.quant_config.enable_kv_quantization = True + # pylint: disable-next=all + env = environment.JetEngineEnvironment(env._data) - cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) - cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3)) - cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3)) cache_int = cache_manager.Int8KVCacheGenerate( cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, - [0], + None, None, env, ) - attention_quant = layers.Int8KVAttentionKernel(env) - int_res = attention_quant(xq, xk, xv, None, cache_int) + # layer_id doesn't matter, will assign later + attention_quant = layers.Int8KVAttentionKernel(env, layer_id=0) + + int_res = [] + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 57) + ) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 58) + ) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 2, 3) + ) - self.assertTrue(jnp.allclose(float_res.jax(), int_res.jax(), atol=0.01)) + for f, i in zip(float_res, int_res): + self.assertTrue(jnp.allclose(f.jax(), i.jax(), atol=0.01)) def test_quantize_dequantize_tensor(self): + """Test quantize and dequantize tensor.""" def quantize_dequantize_weight(w, n_bit): # print(f"original w {w}") @@ -187,10 +334,9 @@ def quantize_dequantize_weight(w, n_bit): quantize_dequantize_weight(w, bit) def test_weight_only_quant(self): - + """Test weight only quantization.""" out_features = 2048 in_features = 2048 - block_size = 128 arg = torch.randn(2, 16, in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( @@ -231,24 +377,27 @@ def test_weight_only_quant(self): in_features, out_features, quant_config=quant_config ) # block_q_linear.run_fake_quantize = True - res, torch_res, block_diff2 = self._nn_linear_run_and_compare( + res, torch_res, _ = self._nn_linear_run_and_compare( nn_linear, block_q_linear, arg ) # self._print_diff(res, torch_res) self.assertLess(per_channel_diff2.norm(), per_channel_diff.norm()) + # pylint: disable-next=all # FIXME: Now asymmetric blockwise quant has higher error than asymmetric per-channel. 
# self.assertLess(block_diff2.norm(), per_channel_diff2.norm()) def test_int4_weight_loading(self): + """Test int4 weight loading.""" layer = WeightOnlyBlockwiseQuantizedLinear(1024, 2048) state_dict_jax = torchjax.from_torch( helpers.to_xla_tensor(layer.state_dict()) ) state_dict_jax["weight"] = state_dict_jax["weight"].astype(jnp.int4) state_dict_torch = torchjax.to_torch(state_dict_jax) - self.assertTrue(state_dict_torch["weight"]._elem.dtype == jnp.int4) + self.assertTrue(state_dict_torch["weight"].jax().dtype == jnp.int4) def test_blockwise_quantized_linear_sharding(self): + """Test blockwise quantized linear sharding.""" @functools.partial( jax.jit, @@ -264,19 +413,20 @@ def f(layer, weights, args): state_dict_jax = torchjax.from_torch( helpers.to_xla_tensor(layer.state_dict()) ) - input = jax.random.normal( + inputs = jax.random.normal( jax.random.key(0), shape=(2, 32, 1024), dtype=jnp.bfloat16 ) - def shard_and_lower(f, layer, state_dict_jax, input, shardings): + def shard_and_lower(f, layer, state_dict_jax, inputs, shardings): for k, v in state_dict_jax.items(): if k == "weight": state_dict_jax[k] = v.astype(jnp.int4) state_dict_jax[k] = jax.device_put(v, sharding[0]) if k == "weight_scaler": state_dict_jax[k] = jax.device_put(v, sharding[1]) - pre_opt = f.lower(layer, state_dict_jax, input).as_text("hlo") - post_opt = f.lower(layer, state_dict_jax, input).compile().as_text() + # pre opt, for debugging + _ = f.lower(layer, state_dict_jax, inputs).as_text("hlo") + post_opt = f.lower(layer, state_dict_jax, inputs).compile().as_text() return post_opt env, _ = helpers.make_env_tiny() @@ -286,15 +436,16 @@ def shard_and_lower(f, layer, state_dict_jax, input, shardings): # (sharding_by_axis(1), sharding_by_axis(0)), # bad sharding ] for sharding in shardings: - opt_hlo = shard_and_lower(f, layer, state_dict_jax, input, sharding) + opt_hlo = shard_and_lower(f, layer, state_dict_jax, inputs, sharding) self.assertFalse("all-to-all" in opt_hlo) self.assertFalse("all-reduce-scatter" in opt_hlo) def test_activation_quant_per_channel(self): - + """Test activation quantization channel mode.""" out_features = 8 in_features = 4 - block_size = 128 + # Block size + _ = 128 arg = torch.randn(2, 1, in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( @@ -313,10 +464,11 @@ def test_activation_quant_per_channel(self): self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999) def test_quant_creator(self): - + """Test quantization creator.""" out_features = 8 in_features = 4 - block_size = 128 + # Block size + _ = 128 arg = torch.randn(2, 1, in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( @@ -333,8 +485,10 @@ def test_quant_creator(self): self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999) def test_3_layers(self): + """Test 3 layers.""" class Model(torch.nn.Module): + """Model.""" def __init__(self): super().__init__() @@ -343,6 +497,7 @@ def __init__(self): self.linear3 = torch.nn.Linear(2048, 1024, bias=False) def forward(self, x): + """Forward function.""" x = self.linear1(x) x = self.linear2(x) x = self.linear3(x) diff --git a/tests/test_run_server.py b/tests/test_run_server.py index 849af329..73022a74 100644 --- a/tests/test_run_server.py +++ b/tests/test_run_server.py @@ -20,29 +20,37 @@ class MockServer(MagicMock): + """Mock server.""" def run(self, **kwargs): + """Run.""" return self def wait_for_termination(self): + """Wait for termination.""" raise SystemExit("Successfully exited test.") def mock_engine(**kwargs): + """Mock engine.""" return kwargs 
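The server tests below snapshot and restore absl flag values by hand via `flagsaver.save_flag_values()` and `flagsaver.restore_flag_values()`. For reference, a minimal self-contained example of that pattern with a made-up flag (not one of the repo's flags):

```python
from absl import flags
from absl.testing import flagsaver

flags.DEFINE_integer("demo_port", 9000, "hypothetical flag used only for illustration")
FLAGS = flags.FLAGS
FLAGS.mark_as_parsed()

saved = flagsaver.save_flag_values()  # snapshot every flag's current value
FLAGS.demo_port = 1234                # a test mutates flags freely...
assert FLAGS.demo_port == 1234
flagsaver.restore_flag_values(saved)  # ...and puts them back afterwards
assert FLAGS.demo_port == 9000
```

absl also ships a `@flagsaver.flagsaver(...)` decorator that scopes the save/restore to a single test automatically.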
class ServerRunTest(unittest.TestCase): + """Server run test.""" def reset_flags(self): + """Reset flag.""" flagsaver.restore_flag_values(self.original) def setup(self): + """Setup.""" from run_server import flags - FLAGS = flags.FLAGS + f = flags.FLAGS + # pylint: disable-next=all self.original = flagsaver.save_flag_values() - return FLAGS + return f @parameterized.expand( [ @@ -61,50 +69,51 @@ def test_no_change_from_defaults(self, args, expected): args (List): List to simulate sys.argv with dummy first entry at index 0. expected (str): model_name flag value to inspect """ + # pylint: disable-next=all from run_server import main - FLAGS = self.setup() + f = self.setup() with self.assertRaisesRegex(SystemExit, "Successfully exited test."): app.run(main, args) # run_server - self.assertEqual(FLAGS.port, 9000) - self.assertEqual(FLAGS.threads, 64) - self.assertEqual(FLAGS.config, "InterleavedCPUTestServer") - self.assertEqual(FLAGS.prometheus_port, 0) - self.assertEqual(FLAGS.enable_jax_profiler, False) - self.assertEqual(FLAGS.jax_profiler_port, 9999) + self.assertEqual(f.port, 9000) + self.assertEqual(f.threads, 64) + self.assertEqual(f.config, "InterleavedCPUTestServer") + self.assertEqual(f.prometheus_port, 0) + self.assertEqual(f.enable_jax_profiler, False) + self.assertEqual(f.jax_profiler_port, 9999) # quantization configs - self.assertEqual(FLAGS.quantize_weights, False) - self.assertEqual(FLAGS.quantize_activation, False) - self.assertEqual(FLAGS.quantize_type, "int8_per_channel") - self.assertEqual(FLAGS.quantize_kv_cache, False) + self.assertEqual(f.quantize_weights, False) + self.assertEqual(f.quantize_activation, False) + self.assertEqual(f.quantize_type, "int8_per_channel") + self.assertEqual(f.quantize_kv_cache, False) # engine configs - self.assertEqual(FLAGS.tokenizer_path, None) - self.assertEqual(FLAGS.checkpoint_path, None) - self.assertEqual(FLAGS.bf16_enable, True) - self.assertEqual(FLAGS.context_length, 1024) - self.assertEqual(FLAGS.batch_size, 32) - self.assertEqual(FLAGS.size, "tiny") - self.assertEqual(FLAGS.max_cache_length, 1024) - self.assertEqual(FLAGS.shard_on_batch, False) - self.assertEqual(FLAGS.sharding_config, "") - self.assertEqual(FLAGS.ragged_mha, False) - self.assertEqual(FLAGS.starting_position, 512) - self.assertEqual(FLAGS.temperature, 1.0) - self.assertEqual(FLAGS.sampling_algorithm, "greedy") - self.assertEqual(FLAGS.nucleus_topp, 0.0) - self.assertEqual(FLAGS.topk, 0) - self.assertEqual(FLAGS.ring_buffer, True) + self.assertEqual(f.tokenizer_path, None) + self.assertEqual(f.checkpoint_path, None) + self.assertEqual(f.bf16_enable, True) + self.assertEqual(f.context_length, 1024) + self.assertEqual(f.batch_size, 32) + self.assertEqual(f.size, "tiny") + self.assertEqual(f.max_cache_length, 1024) + self.assertEqual(f.shard_on_batch, False) + self.assertEqual(f.sharding_config, "") + self.assertEqual(f.ragged_mha, False) + self.assertEqual(f.starting_position, 512) + self.assertEqual(f.temperature, 1.0) + self.assertEqual(f.sampling_algorithm, "greedy") + self.assertEqual(f.nucleus_topp, 0.0) + self.assertEqual(f.topk, 0) + self.assertEqual(f.ring_buffer, True) # profiling configs - self.assertEqual(FLAGS.profiling_prefill, False) - self.assertEqual(FLAGS.profiling_output, "") + self.assertEqual(f.profiling_prefill, False) + self.assertEqual(f.profiling_output, "") # model_name flag updates - self.assertEqual(FLAGS.model_name, expected) + self.assertEqual(f.model_name, expected) # reset back to original flags self.reset_flags() @@ -112,7 +121,8 @@ 
def test_no_change_from_defaults(self, args, expected): @parameterized.expand([param(["test1", "--model_name", "llama3"])]) @patch("jetstream_pt.engine.create_pytorch_engine", mock_engine) def test_call_server_object(self, args): - """tests whether running the main script from absl.app.run launches a server and waits for termination + """tests whether running the main script from absl.app.run launches a server + and waits for termination Args: args (List): List to simulate sys.argv with dummy first entry at index 0. @@ -120,9 +130,10 @@ def test_call_server_object(self, args): with patch( "jetstream.core.server_lib.run", autospec=MockServer().run ) as mock_server: + # pylint: disable-next=all from run_server import main - FLAGS = self.setup() + _ = self.setup() with self.assertRaises(SystemExit): app.run(main, args) self.assertEqual(mock_server.call_count, 1)
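The last test patches `jetstream.core.server_lib.run` and asserts it is invoked exactly once when the CLI main runs. As a general shape for that kind of test, here is a self-contained sketch of exercising an absl `app.run` entry point against a mocked launcher; the module and function names are hypothetical, not the repo's layout:

```python
import unittest
from unittest.mock import patch

from absl import app


def serve(port):
    """Stand-in for a real server launcher; must never run in tests."""
    raise RuntimeError("real server should not start in tests")


def main(argv):
    del argv
    serve(port=9000)


class MainTest(unittest.TestCase):

    def test_main_starts_server_once(self):
        # app.run() exits via sys.exit() after main returns, hence SystemExit.
        with patch(f"{__name__}.serve") as mock_serve:
            with self.assertRaises(SystemExit):
                app.run(main, ["prog"])  # dummy argv[0], as in the tests above
            mock_serve.assert_called_once_with(port=9000)


if __name__ == "__main__":
    unittest.main()
```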