diff --git a/README.md b/README.md
index b28e5438..bb7cd63c 100644
--- a/README.md
+++ b/README.md
@@ -122,6 +122,41 @@ Optional flags:
 * `--sharding_config=` This makes use of alternative sharding config instead of the ones in
 default_shardings directory.
+
+# Run the server with ray
+Below are the steps to run the server with Ray:
+1. SSH to a multi-host Cloud TPU VM (e.g. a v5e-16 TPU VM)
+2. Follow steps 2 to 5 in the Outline above
+3. Set up the Ray cluster
+4. Run the server with Ray
+
+## Setup Ray Cluster
+Log in to the host 0 VM and start the Ray head node with the following command:
+
+```bash
+ray start --head
+```
+
+Log in to each of the other host VMs and start a Ray worker that joins the head node:
+
+```bash
+ray start --address="$ip:$port"
+```
+
+Note: The IP address and port to use are printed by the Ray head node when it starts (the default port is 6379).
+
+## Run server with ray
+
+Here is an example of running the server with Ray for the Llama 2 7B model:
+
+```bash
+python run_server_with_ray.py \
+  --tpu_chips=16 \
+  --model_name=$model_name \
+  --size=7b \
+  --batch_size=96 \
+  --max_cache_length=2048 \
+  --quantize_weights=$quantize \
+  --quantize_type=$quantize_type \
+  --quantize_kv_cache=$quantize \
+  --checkpoint_path=$output_ckpt_dir \
+  --tokenizer_path=$tokenizer_path \
+  --sharding_config="default_shardings/llama.yaml"
+```
+
 # Run benchmark
 Start the server and then go to the deps/JetStream folder (downloaded during `install_everything.sh`)
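
Before launching the server with Ray, it can help to confirm that every host VM has actually joined the cluster. A minimal sanity check, using the standard Ray CLI (run on the head VM; this assumes Ray is already installed there, as the setup steps above require):

```bash
# Run on the head VM after the workers have been started.
# The output should list one node entry per TPU host VM,
# plus the cluster's aggregated CPU/TPU resources.
ray status
```

If a host is missing from the output, re-run the `ray start --address=...` command on that VM with the address printed by the head node.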