@@ -52,24 +52,7 @@ def create_engine(**kwargs):
52
52
os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "0"
53
53
54
54
start = time .perf_counter ()
55
- engine = ray_engine .create_pytorch_ray_engine (
56
- model_name = kwargs ["model_name" ],
57
- tokenizer_path = kwargs ["tokenizer_path" ],
58
- ckpt_path = kwargs ["ckpt_path" ],
59
- bf16_enable = kwargs ["bf16_enable" ],
60
- param_size = kwargs ["param_size" ],
61
- context_length = kwargs ["context_length" ],
62
- batch_size = kwargs ["batch_size" ],
63
- quantize_weights = kwargs ["quantize_weights" ],
64
- quantize_kv = kwargs ["quantize_kv" ],
65
- max_cache_length = kwargs ["max_cache_length" ],
66
- sharding_config = kwargs ["sharding_config" ],
67
- num_hosts = kwargs ["num_hosts" ],
68
- worker_chips = kwargs ["worker_chips" ],
69
- tpu_chips = kwargs ["tpu_chips" ],
70
- enable_jax_profiler = kwargs ["enable_jax_profiler" ],
71
- jax_profiler_port = kwargs ["jax_profiler_port" ],
72
- )
55
+ engine = ray_engine .create_pytorch_ray_engine (** kwargs )
73
56
74
57
print ("Initialize engine" , time .perf_counter () - start )
75
58
return engine
0 commit comments