Skip to content

Commit 9717eb9

Browse files
authored
Fix exception in ray_worker (#144)
* Fix exception in ray_worker * fix format
1 parent 508d71c commit 9717eb9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

jetstream_pt/ray_worker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,15 +350,17 @@ def _call_model_generate(
350350
new_mask = mask.at[:, current_position].set(0)
351351
if self.env.quant_config.enable_kv_quantization:
352352
caches_obj = [
353-
cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
353+
cache_manager.Int8KVCacheGenerate(
354+
k, v, ks, vs, input_indexes, env=self.env
355+
)
354356
for (k, v), (ks, vs) in torchjax.to_torch(
355357
list(zip(caches, cache_scales))
356358
)
357359
]
358360
else:
359361
caches_obj = [
360362
cache_manager.KVCacheGenerate(
361-
k, v, input_indexes, self.cache_sharding
363+
k, v, input_indexes, self.cache_sharding, env=self.env
362364
)
363365
for k, v in torchjax.to_torch(caches)
364366
]

0 commit comments

Comments
 (0)