transformers/tests/models/gemma
fxmarty 92abe60334
>3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227)
* draft

* apply changes to all relevant archs

* rerun ci - check_docstrings.py failing?

* fix docstring

* move 2D->4D mask creation to modeling file

* repo consistency

* fix the batch size = 1 case - calling contiguous is not enough

* nit

* style

* propagate to gemma/gemma-2

* prepare inputs for gemma generation

* implement test and tiny fix in gemma2

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix copies

* ci pass

* fix gemma's test_compile_static_cache tests

* flaky

* retrigger ci

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-08-01 02:03:07 +08:00
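The commit above mentions moving 2D→4D attention-mask creation into the modeling file so it can be traced by `torch.compile`. A minimal sketch of that kind of expansion (a hypothetical helper, not the PR's actual implementation): a `(batch, kv_len)` padding mask of 1s/0s is broadcast to a `(batch, 1, q_len, kv_len)` additive mask, with masked positions set to the dtype's minimum value.

```python
import torch

def expand_2d_to_4d_mask(mask_2d: torch.Tensor, q_len: int,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Expand a (batch, kv_len) 1/0 padding mask into a (batch, 1, q_len, kv_len)
    additive mask: 0.0 where attention is allowed, dtype-min where it is masked."""
    batch, kv_len = mask_2d.shape
    # Broadcast over the head and query dimensions, then invert so that
    # padded (0) positions become 1.0 and get filled with the minimum value.
    expanded = mask_2d[:, None, None, :].expand(batch, 1, q_len, kv_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 0]])           # one padded position
mask_4d = expand_2d_to_4d_mask(mask, q_len=2)
```

Keeping this step inline in the modeling code (rather than in a separate utility invoked at runtime) is what lets the compiler capture the whole forward pass in one graph.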
__init__.py [gemma] Adds support for Gemma 💎 (#29167) 2024-02-21 14:21:28 +01:00
test_modeling_flax_gemma.py FIX [Gemma / CI] Make sure our runners have access to the model (#29242) 2024-02-28 06:25:23 +01:00
test_modeling_gemma.py >3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227) 2024-08-01 02:03:07 +08:00
test_tokenization_gemma.py Fix slow GemmaTokenizer and improve SPM slow -> fast conversion process (#32191) 2024-07-30 23:36:38 +02:00