Commit 85d2917

fix: Update recurrent cache for changes to remove intermediate kv_cache interface

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 6388cf6 commit 85d2917

File tree: 2 files changed (+93 lines, -86 lines)

src/llama-kv-cache-hybrid-recurrent.cpp

Lines changed: 66 additions & 63 deletions
@@ -49,50 +49,6 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
         n_seq_max
     )) {}

-void llama_kv_cache_hybrid_recurrent::clear() {
-    kv_attn     ->clear();
-    kv_recurrent->clear();
-}
-
-bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    // Try removing from the recurrent cache first since it may fail. If it does
-    // fail, the cache will not have been mutated.
-    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
-        return false;
-    }
-    return kv_attn->seq_rm(seq_id, p0, p1);
-}
-
-void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
-    kv_attn     ->seq_keep(seq_id);
-    kv_recurrent->seq_keep(seq_id);
-}
-
-void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-    kv_attn->seq_add(seq_id, p0, p1, shift);
-    kv_recurrent->seq_add(seq_id, p0, p1, shift);
-}
-
-void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    kv_attn     ->seq_div(seq_id, p0, p1, d);
-    kv_recurrent->seq_div(seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
-    // the min of the total cache is the max of the two caches' min values
-    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
-}
-
-llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
-    // the max of the total cache is the min of the two caches' max values
-    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
-}
-
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {

     // since this includes a recurrent cache, we cannot use split_simple
@@ -135,23 +91,59 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
 }

-bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
-    bool res = false;
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this,
+        static_cast<llama_kv_cache_unified_state *>(  kv_attn     ->init_update(lctx, optimize).release()),
+        static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return kv_attn->get_can_shift();
+}
+
+void llama_kv_cache_hybrid_recurrent::clear() {
+    kv_attn     ->clear();
+    kv_recurrent->clear();
+}

-    res = res | kv_attn     ->update(lctx);
-    res = res | kv_recurrent->update(lctx);
+bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return kv_attn->seq_rm(seq_id, p0, p1);
+}

-    return res;
+void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }

-void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
-    kv_attn     ->defrag_sched(thold);
-    kv_recurrent->defrag_sched(thold);
+void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
+    kv_attn     ->seq_keep(seq_id);
+    kv_recurrent->seq_keep(seq_id);
 }

-bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
-    // Shifting is trivially supported for recurrent
-    return kv_attn->get_can_shift();
+void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_attn->seq_add(seq_id, p0, p1, shift);
+    kv_recurrent->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_attn     ->seq_div(seq_id, p0, p1, d);
+    kv_recurrent->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
 }

 void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
@@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
 }

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
-    : status(status), state_attn(status), state_recurrent(status) {}
+    : status(status),
+      state_attn(new llama_kv_cache_unified_state(status)),
+      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
       kv(kv),
-      state_attn(status, kv->get_kv_attn()),
-      state_recurrent(status, kv->get_kv_recurrent()) {}
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
+      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_kv_cache_unified_state * state_unified,
+        llama_kv_cache_recurrent_state * state_recurrent)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+      kv(kv),
+      state_attn(state_unified),
+      state_recurrent(state_recurrent) {}

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
     llama_kv_cache_hybrid_recurrent * kv,
@@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
       // NOTE: these child states are only used as wrapper APIs for the
       // const methods, so we use the "init full" signature since the
       // actual state is not used.
-      state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
-      state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {}
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
+      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}


 bool llama_kv_cache_hybrid_recurrent_state::next() {
@@ -232,10 +235,10 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
     return ubatches[i_next];
 }

-const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const {
-    return &state_attn;
+const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
+    return state_attn.get();
 }

 const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
-    return &state_recurrent;
+    return state_recurrent.get();
 }
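
Not part of the commit itself, but for orientation: a minimal caller-side sketch of the flow this change introduces, where init_update() replaces the old boolean update()/defrag_sched() pair and the returned hybrid state exposes its per-cache views through get_state_attn() and get_state_recurrent(). The helper name run_cache_update and the availability of the internal llama.cpp headers on the include path are assumptions for illustration only.

// Hypothetical helper, assuming llama.cpp's internal headers are available.
#include "llama-kv-cache-hybrid-recurrent.h"

// Sketch: drive a pending cache update and inspect the per-cache child states.
static llama_memory_state_ptr run_cache_update(llama_kv_cache_hybrid_recurrent & cache,
                                               llama_context * lctx,
                                               bool optimize) {
    // init_update() builds one child state per sub-cache (attention + recurrent)
    // and hands ownership of both to the returned hybrid state.
    llama_memory_state_ptr state = cache.init_update(lctx, optimize);

    // Graph-building code can then reach each sub-cache's view through the wrapper:
    const auto * hybrid =
        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(state.get());
    const llama_kv_cache_unified_state   * attn = hybrid->get_state_attn();
    const llama_kv_cache_recurrent_state * recr = hybrid->get_state_recurrent();
    (void) attn;
    (void) recr;

    return state;
}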

src/llama-kv-cache-hybrid-recurrent.h

Lines changed: 27 additions & 23 deletions
@@ -2,9 +2,10 @@

 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
 #include "llama-kv-cache-recurrent.h"
 #include "llama-kv-cache-unified.h"
+#include "llama-kv-cells.h"
+#include "llama-memory.h"

 #include <memory>
 #include <vector>
@@ -16,7 +17,7 @@
 // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
 // support models where each layer may be either attention-based or recurrent

-class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
+class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_hybrid_recurrent(
             const llama_model & model,
@@ -42,21 +43,6 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
     // llama_memory_i
     //

-    void clear() override;
-
-    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
-    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -65,12 +51,21 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {

     llama_memory_state_ptr init_full() override;

-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

     bool get_can_shift() const override;

+    void clear() override;
+
+    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load

     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -92,12 +87,21 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {

 class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
 public:
+    using llama_kv_cache_unified_state_ptr = std::unique_ptr<llama_kv_cache_unified_state>;
+    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
+
     // init failure
     explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);

     // init full
     explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);

+    // init update
+    explicit llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_kv_cache_unified_state * state_unified,
+        llama_kv_cache_recurrent_state * state_recurrent);
+
     // init success
     llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
@@ -116,7 +120,7 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
     const llama_ubatch & get_ubatch() const override;

     //
-    // llama_kv_cache_hybrid_recurrent_state_i
+    // llama_kv_cache_hybrid_recurrent_state
     //

     const llama_kv_cache_unified_state * get_state_attn () const;
@@ -135,6 +139,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
     std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;

-    const llama_kv_cache_unified_state state_attn;
-    const llama_kv_cache_recurrent_state state_recurrent;
+    const llama_kv_cache_unified_state_ptr state_attn;
+    const llama_kv_cache_recurrent_state_ptr state_recurrent;
 };
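
Also illustrative rather than from the commit: the members at the bottom of this header change from by-value child states to const unique_ptr members (via the new *_state_ptr aliases), so the "init update" constructor can adopt states release()'d by the sub-caches while the accessors keep handing out non-owning raw pointers via .get(). A self-contained sketch of that ownership pattern, with made-up type names:

#include <memory>

// Stand-in for a child state type such as llama_kv_cache_unified_state.
struct child_state {
    int value = 0;
};

class composite_state {
public:
    using child_state_ptr = std::unique_ptr<child_state>;

    // "init full" style: the composite builds its own child.
    composite_state() : child(new child_state()) {}

    // "init update" style: adopt a child created (and release()'d) elsewhere.
    explicit composite_state(child_state * adopted) : child(adopted) {}

    // Accessors return non-owning pointers, mirroring get_state_attn()/get_state_recurrent().
    const child_state * get_child() const { return child.get(); }

private:
    const child_state_ptr child;  // owned for the composite's lifetime
};

int main() {
    composite_state own_it;                       // child created internally
    composite_state adopt_it(new child_state());  // child adopted from a raw pointer
    return own_it.get_child()->value + adopt_it.get_child()->value;
}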
