transformers/tests
Tom Aarsen 633215ba58
Generate: New Cache abstraction and Attention Sinks support (#26681)
* Draft version of new KV Caching

This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
/ StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
in a third-party library or in transformers directly

* Address numerous PR suggestions

1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
3. Remove __bool__ and __getitem__ magic as they're confusing.
4. past_key_values.update(key, value, idx) now returns key, value.
5. Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until either 1) the cache is used in generate() or 2) use_legacy_cache defaults to True in generate(), pending a change in another PR.
6. Separate key_cache and value_cache.

Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.

* Implement the SinkCache through backward+forward rotations

* Integrate (Sink)Cache with Llama FA2

* Set use_legacy_cache=True as default, allows for test passes

* Move from/to_legacy_cache to ...Model class

* Undo unnecessary newline change

* Remove copy utility from deprecated OpenLlama

* Match import style

* manual rebase with main

* Cache class working with generate (#1)

* Draft version of new KV Caching

This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
/ StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
in a third-party library or in transformers directly

* Address numerous PR suggestions

1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
3. Remove __bool__ and __getitem__ magic as they're confusing.
4. past_key_values.update(key, value, idx) now returns key, value.
5. Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until either 1) the cache is used in generate() or 2) use_legacy_cache defaults to True in generate(), pending a change in another PR.
6. Separate key_cache and value_cache.

Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.

* Integrate (Sink)Cache with Llama FA2

* Move from/to_legacy_cache to ...Model class

* Undo unnecessary newline change

* Match import style

* working generate

* Add tests; Simplify code; Apply changes to Mistral and Persimmon

* fix rebase mess

* a few more manual fixes

* last manual fix

* propagate changes to phi

* upgrade test

* add use_legacy_cache docstring; beef up tests

* reintroduce unwanted deletes

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>

* move import

* add default to model_kwargs.get('use_legacy_cache')

* correct failing test

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* apply PR suggestions

* fix failing test

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>

* PR comments

* tmp commit

* add docstrings

* more tests, more docstrings, add to docs

* derp

* tmp commit

* tmp dbg

* more dbg

* fix beam search bug

* cache can be a list of tuples in some models

* fix group beam search

* all but sinkcache integration tests

* fix sink cache and add hard integration test

* now also compatible with input_embeds input

* PR comments

* add Cache support to Phi+FA2

* make fixup

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-08 09:00:17 +01:00
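
The commit message above says that past_key_values.update(key, value, idx) returns key, value and that key_cache and value_cache are stored separately. A minimal sketch of that contract follows. It is a toy illustration, not the transformers implementation: the ToyCache class and the nested-list Tensor alias are hypothetical, and plain Python lists stand in for real tensors.

```python
from typing import List, Tuple

# Hypothetical stand-in for a tensor: [seq_len][head_dim].
Tensor = List[List[float]]


class ToyCache:
    """Toy KV cache keeping keys and values in separate per-layer lists."""

    def __init__(self) -> None:
        self.key_cache: List[Tensor] = []
        self.value_cache: List[Tensor] = []

    def update(self, key: Tensor, value: Tensor, layer_idx: int) -> Tuple[Tensor, Tensor]:
        """Store new states for `layer_idx` and return the full cached states.

        Assumes layers are first updated in order (0, 1, 2, ...).
        """
        if layer_idx == len(self.key_cache):
            # First call for this layer: start its cache with a copy.
            self.key_cache.append(list(key))
            self.value_cache.append(list(value))
        else:
            # Later calls: extend along the sequence axis.
            self.key_cache[layer_idx].extend(key)
            self.value_cache[layer_idx].extend(value)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


cache = ToyCache()
cache.update([[1.0, 2.0]], [[0.1, 0.2]], layer_idx=0)
keys, values = cache.update([[3.0, 4.0]], [[0.3, 0.4]], layer_idx=0)
print(len(keys), len(values))  # prints "2 2"
```

Returning the accumulated key and value from update lets an attention layer use the result directly, without checking what kind of cache it was handed.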
..
benchmark [Test refactor 1/5] Per-folder tests reorganization (#15725) 2022-02-23 15:46:28 -05:00
bettertransformer Fixed malapropism error (#26660) 2023-10-09 11:04:57 +02:00
deepspeed device-agnostic deepspeed testing (#27342) 2023-11-09 12:34:13 +01:00
extended Device agnostic trainer testing (#27131) 2023-10-30 18:16:40 +00:00
fixtures [WIP] add SpeechT5 model (#18922) 2023-02-03 12:43:46 -05:00
fsdp device agnostic fsdp testing (#27120) 2023-11-01 07:17:06 +01:00
generation Generate: New Cache abstraction and Attention Sinks support (#26681) 2023-12-08 09:00:17 +01:00
models Fix TF loading PT safetensors when weights are tied (#27490) 2023-12-07 14:28:53 +00:00
optimization Make schedulers picklable by making lr_lambda fns global (#21768) 2023-03-02 12:08:43 -05:00
peft_integration [Peft] modules_to_save support for peft integration (#27466) 2023-11-14 10:32:57 +01:00
pipelines [Llava] Add Llava to transformers (#27662) 2023-12-07 09:30:47 +01:00
quantization Faster generation using AWQ + Fused modules (#27411) 2023-12-05 12:14:45 +01:00
repo_utils Allow # Ignore copy (#27328) 2023-12-07 10:00:08 +01:00
sagemaker Broken links fixed related to datasets docs (#27569) 2023-11-17 13:44:09 -08:00
tokenization [Styling] stylify using ruff (#27144) 2023-11-16 17:43:19 +01:00
tools Add support for for loops in python interpreter (#24429) 2023-06-26 09:58:14 -04:00
trainer Fixed passing scheduler-specific kwargs via TrainingArguments lr_scheduler_kwargs (#27595) 2023-11-28 08:33:45 +01:00
utils Update tiny model summary file (#27388) 2023-11-23 21:00:39 +01:00
__init__.py GPU text generation: Moved the encoded_prompt to correct device 2020-01-06 15:11:12 +01:00
test_backbone_common.py [AutoBackbone] Add test (#26094) 2023-09-18 23:47:54 +02:00
test_cache_utils.py Generate: New Cache abstraction and Attention Sinks support (#26681) 2023-12-08 09:00:17 +01:00
test_configuration_common.py [ PretrainedConfig] Improve messaging (#27438) 2023-11-15 14:10:39 +01:00
test_configuration_utils.py Remove-auth-token (#27060) 2023-11-13 14:20:54 +01:00
test_feature_extraction_common.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_feature_extraction_utils.py Remove-auth-token (#27060) 2023-11-13 14:20:54 +01:00
test_image_processing_common.py Input data format (#25464) 2023-08-16 17:45:02 +01:00
test_image_processing_utils.py Remove-auth-token (#27060) 2023-11-13 14:20:54 +01:00
test_image_transforms.py Normalize floating point cast (#27249) 2023-11-10 15:35:27 +00:00
test_modeling_common.py Generate: New Cache abstraction and Attention Sinks support (#26681) 2023-12-08 09:00:17 +01:00
test_modeling_flax_common.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_modeling_flax_utils.py Default to msgpack for safetensors (#27460) 2023-11-13 15:17:01 +01:00
test_modeling_tf_common.py Deprecate TransfoXL (#27607) 2023-11-24 11:48:02 +01:00
test_modeling_tf_utils.py Default to msgpack for safetensors (#27460) 2023-11-13 15:17:01 +01:00
test_modeling_utils.py [ModelOnTheFlyConversionTester] Mark as slow for now (#27823) 2023-12-04 08:33:15 +01:00
test_pipeline_mixin.py Shorten the conversation tests for speed + fixing position overflows (#26960) 2023-10-31 14:20:04 +00:00
test_sequence_feature_extraction_common.py Fix typo (#25966) 2023-09-05 10:12:25 +02:00
test_tokenization_common.py [Styling] stylify using ruff (#27144) 2023-11-16 17:43:19 +01:00
test_tokenization_utils.py Remove-auth-token (#27060) 2023-11-13 14:20:54 +01:00