Commit 2e66bc6
Cache class working with generate (#1)
* Draft version of new KV Caching
This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
/ StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
in a third-party or in transformers directly
* Address numerous PR suggestions
1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
3. Remove __bool__ and __getitem__ magic as they're confusing.
4. past_key_values.update(key, value, idx) now returns key, value.
5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
6. Separate key_cache and value_cache.
Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.
* Integrate (Sink)Cache with Llama FA2
* Move from/to_legacy_cache to ...Model class
* Undo unnecessary newline change
* Match import style
* working generate
* Add tests; Simplify code; Apply changes to Mistral and Persimmon
* fix rebase mess
* a few more manual fixes
* last manual fix
* propagate changes to phi
* upgrade test
* add use_legacy_cache docstring; beef up tests
* reintroduce unwanted deletes
---------
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>1 parent a40037d commit 2e66bc6
File tree
8 files changed
+232
-107
lines changed- src/transformers
- generation
- models
- llama
- mistral
- persimmon
- phi
- tests/generation
8 files changed
+232
-107
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1303 | 1303 | | |
1304 | 1304 | | |
1305 | 1305 | | |
| 1306 | + | |
1306 | 1307 | | |
1307 | 1308 | | |
1308 | 1309 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
2 | | - | |
| 1 | + | |
3 | 2 | | |
4 | 3 | | |
5 | 4 | | |
6 | 5 | | |
7 | | - | |
8 | | - | |
9 | | - | |
10 | | - | |
| 6 | + | |
11 | 7 | | |
12 | | - | |
13 | | - | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
14 | 37 | | |
15 | | - | |
16 | 38 | | |
17 | 39 | | |
18 | 40 | | |
| |||
21 | 43 | | |
22 | 44 | | |
23 | 45 | | |
24 | | - | |
| 46 | + | |
25 | 47 | | |
26 | | - | |
27 | | - | |
| 48 | + | |
| 49 | + | |
28 | 50 | | |
29 | 51 | | |
30 | 52 | | |
| |||
53 | 75 | | |
54 | 76 | | |
55 | 77 | | |
56 | | - | |
57 | | - | |
58 | | - | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
59 | 81 | | |
60 | 82 | | |
61 | 83 | | |
| |||
109 | 131 | | |
110 | 132 | | |
111 | 133 | | |
112 | | - | |
| 134 | + | |
113 | 135 | | |
114 | 136 | | |
115 | 137 | | |
| |||
122 | 144 | | |
123 | 145 | | |
124 | 146 | | |
125 | | - | |
| 147 | + | |
126 | 148 | | |
127 | | - | |
128 | | - | |
| 149 | + | |
| 150 | + | |
129 | 151 | | |
130 | 152 | | |
131 | 153 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
| 27 | + | |
27 | 28 | | |
28 | 29 | | |
29 | 30 | | |
| |||
3226 | 3227 | | |
3227 | 3228 | | |
3228 | 3229 | | |
| 3230 | + | |
| 3231 | + | |
3229 | 3232 | | |
3230 | 3233 | | |
3231 | 3234 | | |
| |||
3561 | 3564 | | |
3562 | 3565 | | |
3563 | 3566 | | |
| 3567 | + | |
| 3568 | + | |
3564 | 3569 | | |
3565 | 3570 | | |
3566 | 3571 | | |
| |||
3948 | 3953 | | |
3949 | 3954 | | |
3950 | 3955 | | |
| 3956 | + | |
| 3957 | + | |
3951 | 3958 | | |
3952 | 3959 | | |
3953 | 3960 | | |
| |||
4288 | 4295 | | |
4289 | 4296 | | |
4290 | 4297 | | |
| 4298 | + | |
| 4299 | + | |
4291 | 4300 | | |
4292 | 4301 | | |
4293 | 4302 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
284 | 284 | | |
285 | 285 | | |
286 | 286 | | |
287 | | - | |
| 287 | + | |
288 | 288 | | |
289 | 289 | | |
290 | | - | |
291 | 290 | | |
| 291 | + | |
292 | 292 | | |
293 | 293 | | |
294 | 294 | | |
| |||
435 | 435 | | |
436 | 436 | | |
437 | 437 | | |
438 | | - | |
| 438 | + | |
439 | 439 | | |
440 | 440 | | |
441 | 441 | | |
| |||
539 | 539 | | |
540 | 540 | | |
541 | 541 | | |
542 | | - | |
| 542 | + | |
543 | 543 | | |
544 | 544 | | |
545 | 545 | | |
| |||
640 | 640 | | |
641 | 641 | | |
642 | 642 | | |
643 | | - | |
| 643 | + | |
644 | 644 | | |
645 | 645 | | |
646 | 646 | | |
| |||
816 | 816 | | |
817 | 817 | | |
818 | 818 | | |
| 819 | + | |
| 820 | + | |
| 821 | + | |
819 | 822 | | |
820 | 823 | | |
821 | 824 | | |
| |||
887 | 890 | | |
888 | 891 | | |
889 | 892 | | |
890 | | - | |
| 893 | + | |
891 | 894 | | |
892 | 895 | | |
893 | 896 | | |
| |||
964 | 967 | | |
965 | 968 | | |
966 | 969 | | |
967 | | - | |
| 970 | + | |
968 | 971 | | |
969 | 972 | | |
970 | 973 | | |
| |||
974 | 977 | | |
975 | 978 | | |
976 | 979 | | |
977 | | - | |
978 | | - | |
979 | | - | |
980 | | - | |
981 | | - | |
982 | | - | |
983 | 980 | | |
984 | 981 | | |
985 | 982 | | |
| |||
0 commit comments