-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathdflash.py
More file actions
146 lines (122 loc) · 5.54 KB
/
dflash.py
File metadata and controls
146 lines (122 loc) · 5.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import time
from types import SimpleNamespace
import torch
from transformers import AutoModelForCausalLM, DynamicCache
from model import DFlashDraftModel, sample, extract_context_feature
# Names of the timed decode stages that ``dflash_generate`` accumulates in its
# ``stage_times`` result; the tuple order becomes that dict's key order.
DFLASH_STAGE_ORDER = (
    "draft",
    "verify",
    "commit",
)
@torch.inference_mode()
def dflash_generate(
    model: DFlashDraftModel,
    target: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    mask_token_id: int,
    max_new_tokens: int,
    block_size: int,
    stop_token_ids: list[int],
    temperature: float = 0.0,
) -> SimpleNamespace:
    """Generate tokens with DFlash block drafting verified by a target model.

    Each decode round:
    1. The draft ``model`` fills the ``block_size - 1`` positions that follow
       the last committed token.
    2. The ``target`` scores the whole block in a single forward pass.
    3. The longest prefix of draft tokens that equals the target's own
       samples is committed, plus one extra token sampled by the target right
       after that prefix.

    With ``block_size == 1`` no drafting happens and each round commits
    exactly one target token.

    Args:
        model: Draft model. It is called with ``target_hidden`` /
            ``noise_embedding`` keyword arguments and must expose ``device``
            and ``target_layer_ids``.
        target: Target causal LM. Its ``model.embed_tokens`` embeds the draft
            block, and its ``lm_head`` projects the draft outputs to logits.
        input_ids: Prompt token ids. Only a batch of one row is supported.
        mask_token_id: Token id used to pre-fill every not-yet-generated slot,
            including the draft positions of each block.
        max_new_tokens: Upper bound on the number of generated tokens (>= 0).
        block_size: Tokens per verification block: the last committed token
            plus ``block_size - 1`` draft positions. Must be >= 1.
        stop_token_ids: Token ids that end generation, or ``None`` to generate
            until ``max_new_tokens`` is reached.
        temperature: Temperature passed to ``sample`` for the target's tokens.
            Draft tokens are sampled with ``sample``'s default arguments.

    Returns:
        A SimpleNamespace with these fields:

        - ``output_ids``: CPU tensor holding the prompt plus generated tokens,
          cut after the first stop token.
        - ``num_input_tokens`` / ``num_output_tokens``: token counts.
        - ``time_to_first_token`` / ``time_per_output_token``: timing stats.
        - ``acceptance_lengths``: tokens committed per round.
        - ``decode_rounds``: number of decode rounds run.
        - ``stage_times``: accumulated seconds per stage in
          ``DFLASH_STAGE_ORDER``.
        - ``round_timestamps``: cumulative seconds since decoding started,
          recorded after each round.

    Raises:
        ValueError: If ``input_ids`` does not have exactly one row,
            ``block_size`` < 1, or ``max_new_tokens`` < 0.
    """
    if input_ids.shape[0] != 1:
        raise ValueError(f"dflash_generate supports batch size 1 only, got {input_ids.shape[0]}")
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    if max_new_tokens < 0:
        raise ValueError(f"max_new_tokens must be >= 0, got {max_new_tokens}")

    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    use_draft = block_size > 1

    # One block of slack past max_length: the last round may start just below
    # max_length and still writes up to index start + block_size (the block
    # plus its bonus token).
    output_ids = torch.full(
        (1, max_length + block_size),
        mask_token_id,
        dtype=torch.long,
        device=model.device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=model.device).unsqueeze(0)
    # Force an integer dtype: torch.tensor([]) would otherwise default to
    # float32, unlike the long-typed output ids it is compared against.
    stop_token_ids_tensor = (
        None
        if stop_token_ids is None
        else torch.tensor(stop_token_ids, dtype=torch.long, device=model.device)
    )

    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()
    stage_times = empty_stage_times(DFLASH_STAGE_ORDER)

    # Prefill: run the target over the prompt and sample the first new token.
    # When drafting, also keep the hidden-state features the draft conditions on.
    prefill_start = cuda_time()
    output = target(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=use_draft,
    )
    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens : num_input_tokens + 1] = sample(output.logits, temperature)
    if use_draft:
        target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)
    time_to_first_token = cuda_time() - prefill_start

    decode_start = cuda_time()
    round_clock_start = cuda_time()
    # `start` indexes the most recently committed token; every block begins there.
    start = num_input_tokens
    acceptance_lengths = []
    round_timestamps = []
    draft_prefill = True

    while start < max_length:
        block_output_ids = output_ids[:, start : start + block_size].clone()
        block_position_ids = position_ids[:, start : start + block_size]

        if use_draft:
            draft_stage_start = cuda_time()
            noise_embedding = target.model.embed_tokens(block_output_ids)
            # Feed positions from the draft cache's current length through the
            # end of this block.
            draft_hidden = model(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids[:, past_key_values_draft.get_seq_length() : start + block_size],
                past_key_values=past_key_values_draft,
                use_cache=True,
                is_causal=False,
            )
            # Only the last block_size - 1 outputs predict draft tokens.
            draft_logits = target.lm_head(draft_hidden[:, -block_size + 1 :, :])
            # Truncate the draft cache to `start` entries.
            past_key_values_draft.crop(start)
            block_output_ids[:, 1:] = sample(draft_logits)
            draft_stage_elapsed = cuda_time() - draft_stage_start
            if draft_prefill:
                # The first draft call consumes the whole prompt's features.
                # Restart the decode clock instead of counting it as a draft stage.
                draft_prefill = False
                decode_start = cuda_time()
            else:
                stage_times["draft"] += draft_stage_elapsed

        verify_stage_start = cuda_time()
        output = target(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=use_draft,
        )
        stage_times["verify"] += cuda_time() - verify_stage_start

        commit_stage_start = cuda_time()
        posterior = sample(output.logits, temperature)
        # Draft token i is accepted while it equals the target's sample at
        # position i - 1. cumprod zeroes everything after the first mismatch.
        acceptance_length = (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item()
        output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1]
        # Bonus token: the target's own sample right after the accepted prefix.
        output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]

        acceptance_lengths.append(acceptance_length + 1)
        start += acceptance_length + 1
        # Keep target KV entries only for committed positions before `start`.
        past_key_values_target.crop(start)
        if use_draft:
            target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)[:, : acceptance_length + 1, :]
        stage_times["commit"] += cuda_time() - commit_stage_start
        round_timestamps.append(cuda_time() - round_clock_start)

        if stop_token_ids_tensor is not None:
            # Tokens committed this round, including the bonus token at `start`.
            new_tokens = output_ids[:, start - acceptance_length - 1 : start + 1]
            if torch.isin(new_tokens[0], stop_token_ids_tensor).any():
                break

    # Indices [0, start] hold the prompt and every committed token, so cut by
    # length. Filtering out mask_token_id would also drop genuine prompt or
    # generated tokens that happen to share that id.
    output_ids = output_ids[:, : min(start + 1, max_length)]
    if stop_token_ids_tensor is not None:
        stop_token_indices = torch.isin(output_ids[0, num_input_tokens:], stop_token_ids_tensor).nonzero(as_tuple=True)[0]
        if stop_token_indices.numel() > 0:
            output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0].item() + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)
    return SimpleNamespace(
        output_ids=output_ids.cpu(),
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
        decode_rounds=len(acceptance_lengths),
        stage_times=stage_times,
        round_timestamps=round_timestamps,
    )
def cuda_time() -> float:
    """Return ``time.perf_counter()`` after waiting for pending CUDA work.

    Synchronizing first means the timestamp includes kernels that were
    launched asynchronously before the call. When CUDA is unavailable there is
    nothing to wait for. In that case the synchronization is skipped instead
    of raising, so timing also works for CPU-only runs.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()
def empty_stage_times(stage_names: tuple[str, ...]) -> dict[str, float]:
    """Create a fresh per-stage timing table with every entry set to zero.

    Args:
        stage_names: Stage labels, in the order the keys should appear.

    Returns:
        A new dict mapping each stage name to ``0.0``.
    """
    return dict.fromkeys(stage_names, 0.0)