@@ -29,6 +29,19 @@
 from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs, parse_decoder_json
 
 
+def create_tt_page_table(paged_attention_config, tt_model_args):
+    if paged_attention_config is None:
+        return None
+
+    # Implied shuffling of blocks
+    permutation = torch.randperm(paged_attention_config.max_num_blocks)
+    # Page table which maps virtual blocks to physical
+    reverse_permutation = torch.argsort(permutation)
+    return reverse_permutation.reshape(
+        tt_model_args.max_batch_size, paged_attention_config.max_num_blocks // tt_model_args.max_batch_size
+    )
+
+
 def create_tt_model(
     mesh_device,
     instruct,
@@ -48,26 +61,14 @@ def create_tt_model(
     )
     state_dict = tt_model_args.load_state_dict()
 
-    page_table = None
-    paged_attention_config = None
-    tt_kv_cache = None
-
-    if use_paged_kv_cache:
-        paged_attention_config = PagedAttentionConfig(
-            block_size=page_params["page_block_size"],
-            max_num_blocks=page_params["page_max_num_blocks"],
-        )
-        # Implied shuffling of blocks
-        permutation = torch.randperm(paged_attention_config.max_num_blocks)
-        # Page table which maps virtual blocks to physical
-        reverse_permutation = torch.argsort(permutation)
-        page_table = reverse_permutation.reshape(
-            tt_model_args.max_batch_size, paged_attention_config.max_num_blocks // tt_model_args.max_batch_size
-        )
-        paged_attention_config = PagedAttentionConfig(
+    paged_attention_config = (
+        PagedAttentionConfig(
             block_size=page_params["page_block_size"],
             max_num_blocks=page_params["page_max_num_blocks"],
         )
+        if use_paged_kv_cache
+        else None
+    )
 
     model = Transformer(
         args=tt_model_args,
@@ -78,10 +79,9 @@ def create_tt_model(
         paged_attention_config=paged_attention_config,
     )
 
-    if use_paged_kv_cache:
-        tt_kv_cache = [l.attention.layer_past for l in model.layers]
+    tt_kv_cache = [l.attention.layer_past for l in model.layers] if use_paged_kv_cache else None
 
-    return tt_model_args, model, page_table, tt_kv_cache
+    return tt_model_args, model, paged_attention_config, tt_kv_cache
 
 
 # List of supported Parameters for demo.py
@@ -334,7 +334,7 @@ def test_demo(
     for i in range(repeat_batches):
         repeat_batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))])
 
-    model_args, model, page_table, tt_kv_cache = create_tt_model(
+    model_args, model, paged_attention_config, tt_kv_cache = create_tt_model(
         mesh_device,
         instruct=instruct,
         max_batch_size=batch_size,
@@ -381,6 +381,10 @@ def test_demo(
     logger.info("Starting inference...")
     for batch_idx, input_prompts in enumerate(repeat_batch_prompts):
         logger.info(f"Processing batch {batch_idx}")
+
+        # Create new page table for each batch
+        page_table = create_tt_page_table(paged_attention_config, model_args)
+
         profiler.start(f"preprocess_prefill_inputs", iteration=batch_idx)
         text = processor.apply_chat_template(input_prompts, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = process_vision_info(input_prompts)
@@ -458,8 +462,7 @@ def test_demo(
 
         # Start decoding
         iteration = 0
-        # TODO Argmax on device is only supported for batch_size=1
-        argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True
+        argmax_on_device = sampling_params["temperature"] == 0
        if argmax_on_device:
             device_sampling_params = SamplingParams(temperature=0.0, top_k=-1, top_p=1.0)
         else:
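
For context, here is a minimal standalone sketch of the page-table construction that the new create_tt_page_table() performs. The sizes (max_batch_size=2, max_num_blocks=8) are illustrative values, not the demo's defaults; it shows how the random block permutation and its argsort yield a per-user virtual-to-physical mapping, and why calling it once per batch (as the diff now does) gives each batch a fresh block layout:

import torch

max_batch_size = 2   # illustrative value
max_num_blocks = 8   # illustrative value

# Implied shuffling of blocks: a random order over physical blocks
permutation = torch.randperm(max_num_blocks)
# argsort inverts the permutation: entry v is the physical block backing virtual block v
reverse_permutation = torch.argsort(permutation)
# One row per user, with max_num_blocks // max_batch_size virtual blocks each
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)
print(page_table)  # e.g. tensor([[3, 6, 0, 5], [2, 7, 1, 4]])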