Commit 362dfa8

Fix qwen25_vl demo test (#26602)
### Ticket
Link to Github Issue

### Problem description
Provide context for the problem.

### What's changed
Describe the approach used to solve the problem. Summarize the changes made and their impact.

### Checklist
- [ ] [Single card demo tests](https://github.com/tenstorrent/tt-metal/actions/runs/16867424397) CI passes
1 parent 3883482 · commit 362dfa8

File tree: 1 file changed (+26, -23 lines)

models/demos/qwen25_vl/demo/demo.py

Lines changed: 26 additions & 23 deletions
```diff
@@ -29,6 +29,19 @@
 from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs, parse_decoder_json
 
 
+def create_tt_page_table(paged_attention_config, tt_model_args):
+    if paged_attention_config is None:
+        return None
+
+    # Implied shuffling of blocks
+    permutation = torch.randperm(paged_attention_config.max_num_blocks)
+    # Page table which maps virtual blocks to physical
+    reverse_permutation = torch.argsort(permutation)
+    return reverse_permutation.reshape(
+        tt_model_args.max_batch_size, paged_attention_config.max_num_blocks // tt_model_args.max_batch_size
+    )
+
+
 def create_tt_model(
     mesh_device,
     instruct,
```
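
The new helper leans on a standard trick: for a permutation `p`, `torch.argsort(p)` is its inverse. Below is a minimal standalone sketch of what `create_tt_page_table` computes, using toy sizes (`max_num_blocks=8`, `max_batch_size=2`) in place of the demo's real configuration:

```python
import torch

max_num_blocks = 8   # toy value; the demo reads this from PagedAttentionConfig
max_batch_size = 2   # toy value; the demo reads this from tt_model_args

permutation = torch.randperm(max_num_blocks)      # random physical block order
reverse_permutation = torch.argsort(permutation)  # its inverse: virtual -> physical

# argsort inverts the permutation, so composing the two maps is the identity
assert torch.equal(permutation[reverse_permutation], torch.arange(max_num_blocks))

# Each of the max_batch_size users gets an equal slice of the shuffled blocks
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)
print(page_table.shape)  # torch.Size([2, 4])
```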
```diff
@@ -48,26 +61,14 @@ def create_tt_model(
     )
     state_dict = tt_model_args.load_state_dict()
 
-    page_table = None
-    paged_attention_config = None
-    tt_kv_cache = None
-
-    if use_paged_kv_cache:
-        paged_attention_config = PagedAttentionConfig(
-            block_size=page_params["page_block_size"],
-            max_num_blocks=page_params["page_max_num_blocks"],
-        )
-        # Implied shuffling of blocks
-        permutation = torch.randperm(paged_attention_config.max_num_blocks)
-        # Page table which maps virtual blocks to physical
-        reverse_permutation = torch.argsort(permutation)
-        page_table = reverse_permutation.reshape(
-            tt_model_args.max_batch_size, paged_attention_config.max_num_blocks // tt_model_args.max_batch_size
-        )
-    paged_attention_config = PagedAttentionConfig(
+    paged_attention_config = (
+        PagedAttentionConfig(
             block_size=page_params["page_block_size"],
             max_num_blocks=page_params["page_max_num_blocks"],
         )
+        if use_paged_kv_cache
+        else None
+    )
 
     model = Transformer(
         args=tt_model_args,
```
```diff
@@ -78,10 +79,9 @@ def create_tt_model(
         paged_attention_config=paged_attention_config,
     )
 
-    if use_paged_kv_cache:
-        tt_kv_cache = [l.attention.layer_past for l in model.layers]
+    tt_kv_cache = [l.attention.layer_past for l in model.layers] if use_paged_kv_cache else None
 
-    return tt_model_args, model, page_table, tt_kv_cache
+    return tt_model_args, model, paged_attention_config, tt_kv_cache
 
 
 # List of supported Parameters for demo.py
```
```diff
@@ -334,7 +334,7 @@ def test_demo(
     for i in range(repeat_batches):
         repeat_batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))])
 
-    model_args, model, page_table, tt_kv_cache = create_tt_model(
+    model_args, model, paged_attention_config, tt_kv_cache = create_tt_model(
         mesh_device,
         instruct=instruct,
         max_batch_size=batch_size,
```
```diff
@@ -381,6 +381,10 @@ def test_demo(
     logger.info("Starting inference...")
     for batch_idx, input_prompts in enumerate(repeat_batch_prompts):
         logger.info(f"Processing batch {batch_idx}")
+
+        # Create new page table for each batch
+        page_table = create_tt_page_table(paged_attention_config, model_args)
+
         profiler.start(f"preprocess_prefill_inputs", iteration=batch_idx)
         text = processor.apply_chat_template(input_prompts, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = process_vision_info(input_prompts)
```
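
Taken together, `create_tt_model` no longer returns a ready-made page table; it returns the `PagedAttentionConfig` (or `None`), and the batch loop builds a fresh virtual-to-physical mapping per batch. A quick illustrative sketch of that behavior, assuming the `create_tt_page_table` helper from the first hunk is in scope and using `SimpleNamespace` stand-ins with toy sizes for `PagedAttentionConfig` and `tt_model_args`:

```python
from types import SimpleNamespace

# Stand-ins for PagedAttentionConfig and tt_model_args (illustrative sizes only)
config = SimpleNamespace(block_size=32, max_num_blocks=8)
args = SimpleNamespace(max_batch_size=2)

# One independently shuffled page table per batch, as in the new batch loop
tables = [create_tt_page_table(config, args) for _ in range(3)]
print(*tables, sep="\n")  # three (almost surely different) shufflings

# With paged attention disabled, the helper short-circuits to None
assert create_tt_page_table(None, args) is None
```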
```diff
@@ -458,8 +462,7 @@ def test_demo(
 
     # Start decoding
     iteration = 0
-    # TODO Argmax on device is only supported for batch_size=1
-    argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True
+    argmax_on_device = sampling_params["temperature"] == 0
     if argmax_on_device:
         device_sampling_params = SamplingParams(temperature=0.0, top_k=-1, top_p=1.0)
     else:
```
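
The decode-sampling condition also gets simpler: the removed TODO says device-side argmax used to be limited to `batch_size=1`, so the old code fell back to host sampling for larger batches; now only the temperature matters. A minimal sketch of the underlying decision (illustrative, not the demo's actual sampling path): `temperature == 0` means deterministic greedy decoding, anything else means stochastic sampling.

```python
import torch

def pick_next_token(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Greedy argmax at temperature 0, softmax sampling otherwise."""
    if temperature == 0:
        return torch.argmax(logits, dim=-1)  # deterministic given the logits
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

batch_logits = torch.randn(4, 32)  # toy sizes: batch of 4, vocab of 32
print(pick_next_token(batch_logits, temperature=0.0))  # greedy choice per row
print(pick_next_token(batch_logits, temperature=0.7))  # sampled choice per row
```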
