Skip to content

Commit df92015

Browse files
yixinshiYixin Shi
and
Yixin Shi
authored
Use kwargs to simplify the call sites a bit (#175)
* Use kwargs to simplify the call sites a bit * fix pyink error. --------- Co-authored-by: Yixin Shi <[email protected]>
1 parent 9c555a8 commit df92015

File tree

2 files changed

+2
-18
lines changed

2 files changed

+2
-18
lines changed

jetstream_pt/ray_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def create_pytorch_ray_engine(
220220
decode_pod_slice_name: str = None,
221221
enable_jax_profiler: bool = False,
222222
jax_profiler_port: int = 9999,
223+
**kwargs,
223224
) -> Union[
224225
PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]
225226
]:

run_ray_serve_interleave.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,7 @@ def create_engine(**kwargs):
5252
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
5353

5454
start = time.perf_counter()
55-
engine = ray_engine.create_pytorch_ray_engine(
56-
model_name=kwargs["model_name"],
57-
tokenizer_path=kwargs["tokenizer_path"],
58-
ckpt_path=kwargs["ckpt_path"],
59-
bf16_enable=kwargs["bf16_enable"],
60-
param_size=kwargs["param_size"],
61-
context_length=kwargs["context_length"],
62-
batch_size=kwargs["batch_size"],
63-
quantize_weights=kwargs["quantize_weights"],
64-
quantize_kv=kwargs["quantize_kv"],
65-
max_cache_length=kwargs["max_cache_length"],
66-
sharding_config=kwargs["sharding_config"],
67-
num_hosts=kwargs["num_hosts"],
68-
worker_chips=kwargs["worker_chips"],
69-
tpu_chips=kwargs["tpu_chips"],
70-
enable_jax_profiler=kwargs["enable_jax_profiler"],
71-
jax_profiler_port=kwargs["jax_profiler_port"],
72-
)
55+
engine = ray_engine.create_pytorch_ray_engine(**kwargs)
7356

7457
print("Initialize engine", time.perf_counter() - start)
7558
return engine

0 commit comments

Comments
 (0)