@@ -49,50 +49,6 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
49
49
n_seq_max
50
50
)) {}
51
51
52
- void llama_kv_cache_hybrid_recurrent::clear () {
53
- kv_attn ->clear ();
54
- kv_recurrent->clear ();
55
- }
56
-
57
- bool llama_kv_cache_hybrid_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
58
- // Try removing from the recurrent cache first since it may fail. If it does
59
- // fail, the cache will not have been mutated.
60
- if (!kv_recurrent->seq_rm (seq_id, p0, p1)) {
61
- return false ;
62
- }
63
- return kv_attn->seq_rm (seq_id, p0, p1);
64
- }
65
-
66
- void llama_kv_cache_hybrid_recurrent::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
67
- kv_attn ->seq_cp (seq_id_src, seq_id_dst, p0, p1);
68
- kv_recurrent->seq_cp (seq_id_src, seq_id_dst, p0, p1);
69
- }
70
-
71
- void llama_kv_cache_hybrid_recurrent::seq_keep (llama_seq_id seq_id) {
72
- kv_attn ->seq_keep (seq_id);
73
- kv_recurrent->seq_keep (seq_id);
74
- }
75
-
76
- void llama_kv_cache_hybrid_recurrent::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
77
- kv_attn->seq_add (seq_id, p0, p1, shift);
78
- kv_recurrent->seq_add (seq_id, p0, p1, shift);
79
- }
80
-
81
- void llama_kv_cache_hybrid_recurrent::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
82
- kv_attn ->seq_div (seq_id, p0, p1, d);
83
- kv_recurrent->seq_div (seq_id, p0, p1, d);
84
- }
85
-
86
- llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min (llama_seq_id seq_id) const {
87
- // the min of the total cache is the max of the two caches' min values
88
- return std::max (kv_attn->seq_pos_min (seq_id), kv_recurrent->seq_pos_min (seq_id));
89
- }
90
-
91
- llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max (llama_seq_id seq_id) const {
92
- // the max of the total cache is the min of the two caches' max values
93
- return std::min (kv_attn->seq_pos_max (seq_id), kv_recurrent->seq_pos_max (seq_id));
94
- }
95
-
96
52
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch (const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
97
53
98
54
// since this includes a recurrent cache, we cannot use split_simple
@@ -135,23 +91,59 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
135
91
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this );
136
92
}
137
93
138
- bool llama_kv_cache_hybrid_recurrent::update (llama_context & lctx) {
139
- bool res = false ;
94
+ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update (llama_context * lctx, bool optimize) {
95
+ return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
96
+ this ,
97
+ static_cast <llama_kv_cache_unified_state *>( kv_attn ->init_update (lctx, optimize).release ()),
98
+ static_cast <llama_kv_cache_recurrent_state *>(kv_recurrent->init_update (lctx, optimize).release ()));
99
+ }
100
+
101
+ bool llama_kv_cache_hybrid_recurrent::get_can_shift () const {
102
+ // Shifting is trivially supported for recurrent
103
+ return kv_attn->get_can_shift ();
104
+ }
105
+ void llama_kv_cache_hybrid_recurrent::clear () {
106
+ kv_attn ->clear ();
107
+ kv_recurrent->clear ();
108
+ }
140
109
141
- res = res | kv_attn ->update (lctx);
142
- res = res | kv_recurrent->update (lctx);
110
+ bool llama_kv_cache_hybrid_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
111
+ // Try removing from the recurrent cache first since it may fail. If it does
112
+ // fail, the cache will not have been mutated.
113
+ if (!kv_recurrent->seq_rm (seq_id, p0, p1)) {
114
+ return false ;
115
+ }
116
+ return kv_attn->seq_rm (seq_id, p0, p1);
117
+ }
143
118
144
- return res;
119
+ void llama_kv_cache_hybrid_recurrent::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
120
+ kv_attn ->seq_cp (seq_id_src, seq_id_dst, p0, p1);
121
+ kv_recurrent->seq_cp (seq_id_src, seq_id_dst, p0, p1);
145
122
}
146
123
147
- void llama_kv_cache_hybrid_recurrent::defrag_sched ( float thold ) {
148
- kv_attn ->defrag_sched (thold );
149
- kv_recurrent->defrag_sched (thold );
124
+ void llama_kv_cache_hybrid_recurrent::seq_keep (llama_seq_id seq_id ) {
125
+ kv_attn ->seq_keep (seq_id );
126
+ kv_recurrent->seq_keep (seq_id );
150
127
}
151
128
152
- bool llama_kv_cache_hybrid_recurrent::get_can_shift () const {
153
- // Shifting is trivially supported for recurrent
154
- return kv_attn->get_can_shift ();
129
+ void llama_kv_cache_hybrid_recurrent::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
130
+ kv_attn->seq_add (seq_id, p0, p1, shift);
131
+ kv_recurrent->seq_add (seq_id, p0, p1, shift);
132
+ }
133
+
134
+ void llama_kv_cache_hybrid_recurrent::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
135
+ kv_attn ->seq_div (seq_id, p0, p1, d);
136
+ kv_recurrent->seq_div (seq_id, p0, p1, d);
137
+ }
138
+
139
+ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min (llama_seq_id seq_id) const {
140
+ // the min of the total cache is the max of the two caches' min values
141
+ return std::max (kv_attn->seq_pos_min (seq_id), kv_recurrent->seq_pos_min (seq_id));
142
+ }
143
+
144
+ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max (llama_seq_id seq_id) const {
145
+ // the max of the total cache is the min of the two caches' max values
146
+ return std::min (kv_attn->seq_pos_max (seq_id), kv_recurrent->seq_pos_max (seq_id));
155
147
}
156
148
157
149
void llama_kv_cache_hybrid_recurrent::state_write (llama_io_write_i & io, llama_seq_id seq_id) const {
@@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
173
165
}
174
166
175
167
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (llama_memory_status status)
176
- : status(status), state_attn(status), state_recurrent(status) {}
168
+ : status(status),
169
+ state_attn(new llama_kv_cache_unified_state(status)),
170
+ state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
177
171
178
172
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (llama_kv_cache_hybrid_recurrent * kv)
179
173
: status(LLAMA_MEMORY_STATUS_SUCCESS),
180
174
kv(kv),
181
- state_attn(status, kv->get_kv_attn ()),
182
- state_recurrent(status, kv->get_kv_recurrent ()) {}
175
+ state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn ())),
176
+ state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent ())) {}
177
+
178
+ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (
179
+ llama_kv_cache_hybrid_recurrent * kv,
180
+ llama_kv_cache_unified_state * state_unified,
181
+ llama_kv_cache_recurrent_state * state_recurrent)
182
+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
183
+ kv(kv),
184
+ state_attn(state_unified),
185
+ state_recurrent(state_recurrent) {}
183
186
184
187
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state (
185
188
llama_kv_cache_hybrid_recurrent * kv,
@@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
194
197
// NOTE: these child states are only used as wrapper APIs for the
195
198
// const methods, so we use the "init full" signature since the
196
199
// actual state is not used.
197
- state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn ()),
198
- state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent ()) {}
200
+ state_attn(new llama_kv_cache_unified_state( kv->get_kv_attn () )),
201
+ state_recurrent(new llama_kv_cache_recurrent_state( LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent () )) {}
199
202
200
203
201
204
bool llama_kv_cache_hybrid_recurrent_state::next () {
@@ -232,10 +235,10 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
232
235
return ubatches[i_next];
233
236
}
234
237
235
- const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const {
236
- return & state_attn;
238
+ const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const {
239
+ return state_attn. get () ;
237
240
}
238
241
239
242
const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent () const {
240
- return & state_recurrent;
243
+ return state_recurrent. get () ;
241
244
}
0 commit comments