Skip to content

Commit 2347ebb

Browse files
enable passing of max_decode_length as a flag
1 parent b1446aa commit 2347ebb

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

jetstream_pt/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
31 31   flags.DEFINE_string("size", "tiny", "size of model")
32 32   flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize")
33 33   flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
   34 + flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text")
34 35   flags.DEFINE_string("sharding_config", "", "config file for sharding")
35 36   flags.DEFINE_bool(
36 37   "shard_on_batch",
@@ -173,6 +174,7 @@ def create_engine_from_config_flags():
173 174   batch_size=FLAGS.batch_size,
174 175   quant_config=quant_config,
175 176   max_cache_length=FLAGS.max_cache_length,
    177 + max_decode_length=FLAGS.max_decode_length,
176 178   sharding_config=sharding_file_name,
177 179   shard_on_batch=FLAGS.shard_on_batch,
178 180   ragged_mha=FLAGS.ragged_mha,

0 commit comments

Comments (0)