Add gemma and update recent changes to multiple host #74

Merged
merged 1 commit on May 9, 2024
7 changes: 7 additions & 0 deletions jetstream_pt/ray_engine.py
@@ -152,8 +152,14 @@ def create_pytorch_ray_engine(
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
) -> PyTorchRayEngine:

supported_models = ["llama-2", "llama-3", "gemma"]
if model_name not in supported_models:
raise NotImplementedError(
f"Model name should be one of{','.join(supported_models)}"
)
ray.init(ignore_reinit_error=True)
pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()
@@ -183,6 +189,7 @@ def create_pytorch_ray_engine(
quantize_weights=quantize_weights,
quantize_kv=quantize_kv,
max_cache_length=max_cache_length,
sharding_config=sharding_config,
)
engine_workers.append(engine_worker)
engine_master = PyTorchRayEngine(
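The ray_engine.py change above adds a model_name check against supported_models and threads sharding_config through to each worker. Below is a minimal, hypothetical sketch of a call that exercises the new parameters; the paths are placeholders, and the remaining required arguments of create_pytorch_ray_engine (unchanged by this PR) are elided.

```python
# Hypothetical usage sketch of the updated factory (not from this PR).
# Only arguments visible in this diff or in run_interactive_multiple_host.py
# are shown; placeholder paths must be replaced with real ones.
from jetstream_pt import ray_engine

engine = ray_engine.create_pytorch_ray_engine(
    model_name="gemma",          # must be one of "llama-2", "llama-3", "gemma"
    tokenizer_path="/path/to/tokenizer.model",
    ckpt_path="/path/to/checkpoint",
    bf16_enable=True,
    quantize_weights=False,
    quantize_kv=False,
    max_cache_length=1024,
    sharding_config=None,        # falsy -> default_shardings/gemma.yaml (see ray_worker.py)
    # ...other required engine arguments elided...
)
```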
50 changes: 38 additions & 12 deletions jetstream_pt/ray_worker.py
@@ -17,6 +17,7 @@
from typing import Any, List, Optional, Tuple, Union
import threading
import functools
import os
import humanize


@@ -39,6 +40,7 @@
from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model


Mesh = jax.sharding.Mesh
@@ -103,6 +105,7 @@ def __init__(
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
):

jax.config.update("jax_default_prng_impl", "unsafe_rbg")
@@ -144,38 +147,61 @@ def __init__(
checkpoint_format = "safetensors"
checkpoint_path = paths[0]

if not sharding_config:
sharding_config = os.path.join("default_shardings", model_name + ".yaml")

env_data = JetEngineEnvironmentData(
tokenizer_path=tokenizer_path,
checkpoint_path=checkpoint_path,
checkpoint_format=checkpoint_format,
model_type="llama-2-" + param_size,
batch_size=batch_size,
max_decode_length=max_decode_length,
max_input_sequence_length=context_length,
enable_weight_quantization=quantize_weights,
enable_kv_quantization=quantize_kv,
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
sharding_config_path=sharding_config,
)
env = JetEngineEnvironment(env_data)

pt_model = None
if "llama" in model_name:
if model_name.startswith("llama"):

args = model_args.get_model_args(
model_name + "-" + param_size,
context_length,
batch_size,
bf16_enable,
model_name + "-" + param_size, context_length, batch_size, bf16_enable
)
args.device = "meta"
args.quantize = quantize_weights
env_data.cache_shape = (
batch_size,
args.n_kv_heads,
max_cache_length,
args.dim // args.n_heads,
)
env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)
elif model_name == "gemma":
args = gemma_config.get_model_config(param_size)
env_data.cache_shape = (
batch_size,
args.num_key_value_heads,
max_cache_length,
args.head_dim,
)
env_data.model_type = "gemma-" + param_size
env_data.num_layers = args.num_hidden_layers
env = JetEngineEnvironment(env_data)
pt_model = gemma_model.GemmaModel(args, env)
else:
raise RuntimeError(f"Model with name {model_name} not found")

num_params_size = 0
num_params = 0
for _, v in pt_model.state_dict().items():
num_params += 1
num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
print("Number of param Gbytes:", num_params_size / (1 << 30))
print("Number of param: ", num_params)

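In ray_worker.py the KV-cache shape, model type, and layer count are now derived per model family before the JetEngineEnvironment is built. The sketch below restates that branch as a standalone helper for readability; kv_cache_shape is not a function in this PR, and args stands for the llama ModelArgs or gemma config object used above.

```python
# Illustrative restatement of the per-model cache-shape selection shown above
# (not part of the PR).
def kv_cache_shape(model_name, args, batch_size, max_cache_length):
    if model_name.startswith("llama"):
        # llama configs expose n_kv_heads; head_dim is derived as dim // n_heads
        return (batch_size, args.n_kv_heads, max_cache_length, args.dim // args.n_heads)
    if model_name == "gemma":
        # gemma configs expose num_key_value_heads and head_dim directly
        return (batch_size, args.num_key_value_heads, max_cache_length, args.head_dim)
    raise RuntimeError(f"Model with name {model_name} not found")
```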
10 changes: 10 additions & 0 deletions run_interactive_multiple_host.py
@@ -65,6 +65,14 @@
"max_cache_length", 1024, "kv_cache_quantize"
)

_MODEL_NAME = flags.DEFINE_string(
"model_name", None, "model type", required=False
)

_SHARDING_CONFIG = flags.DEFINE_string(
"sharding_config", "", "config file for sharding"
)


def create_engine():
"""create a pytorch engine"""
@@ -73,6 +81,7 @@ def create_engine():

start = time.perf_counter()
engine = ray_engine.create_pytorch_ray_engine(
model_name=_MODEL_NAME.value,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
@@ -82,6 +91,7 @@ def create_engine():
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=_SHARDING_CONFIG.value,
)

print("Initialize engine", time.perf_counter() - start)
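The new flags default model_name to None and sharding_config to an empty string; as the ray_worker.py hunk above shows, a falsy sharding_config makes the worker fall back to a per-model default YAML. A small sketch of that fallback, written as a helper purely for illustration (resolve_sharding_config does not exist in the codebase):

```python
import os

# Illustration only: mirrors the fallback performed in the worker's __init__ above.
def resolve_sharding_config(sharding_config, model_name):
    # An empty string or None selects default_shardings/<model_name>.yaml.
    if not sharding_config:
        return os.path.join("default_shardings", model_name + ".yaml")
    return sharding_config
```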