transformers/tests
Sanchit Gandhi e93103632b
Add bloom flax (#25094)
* First commit

* step 1 working

* add alibi

* placeholder for `scan`

* add matrix mult alibi

* beta scaling factor for bmm

* working v1 - simple forward pass

* move layer_number from attribute to arg in call

* partial functioning scan

* hacky working scan

* add more modifs

* add test

* update scan for new kwarg order

* fix position_ids problem

* fix bug in attention layer

* small fix

- do the alibi broadcasting only once

* prelim refactor

* finish refactor

* alibi shifting

* incorporate dropout_add to attention module

* make style

* make padding work again

* update

* remove bogus file

* up

* get generation to work

* clean code a bit

* added small tests

* adding alibi test

* make CI tests pass:

- change init weight
- add correct tuple for output attention
- add scan test
- make CI tests work

* fix few nits

* fix nit onnx

* fix onnx nit

* add missing dtype args to nn.Modules

* remove debugging statements

* fix scan generate

* Update modeling_flax_bloom.py

* Update test_modeling_flax_bloom.py

* Update test_modeling_flax_bloom.py

* Update test_modeling_flax_bloom.py

* fix small test issue + make style

* clean up

* Update tests/models/bloom/test_modeling_flax_bloom.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fix function name

* small fix test

* forward contrib credits from PR17761

* Fix failing test

* fix small typo documentation

* fix non passing test

- remove device from build alibi

* refactor call

- refactor `FlaxBloomBlockCollection` module

* make style

* upcast to fp32

* cleaner way to upcast

* remove unused args

* remove layer number

* fix scan test

* make style

* fix i4 casting

* fix slow test

* Update src/transformers/models/bloom/modeling_flax_bloom.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove `layer_past`

* refactor a bit

* fix `scan` slow test

* remove useless import

* major changes

- remove unused code
- refactor a bit
- revert import `torch`

* major refactoring

- change build alibi

* remove scan

* fix tests

* make style

* clean-up alibi

* add integration tests

* up

* fix batch norm conversion

* style

* style

* update pt-fx cross tests

* update copyright

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* per-weight check

* style

* line formats

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: haileyschoelkopf <haileyschoelkopf@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2023-07-27 18:24:56 +01:00
benchmark [Test refactor 1/5] Per-folder tests reorganization (#15725) 2022-02-23 15:46:28 -05:00
bettertransformer Add methods to PreTrainedModel to use PyTorch's BetterTransformer (#21259) 2023-04-27 11:03:42 +02:00
bnb [gpt2-int8] Add gpt2-xl int8 test (#24543) 2023-06-28 18:02:13 +02:00
deepspeed accelerate deepspeed and gradient accumulation integrate (#23236) 2023-05-31 15:16:22 +05:30
extended [tests] switch to torchrun (#22712) 2023-04-12 08:25:45 -07:00
fixtures [WIP] add SpeechT5 model (#18922) 2023-02-03 12:43:46 -05:00
generation update use_auth_token -> token (#25083) 2023-07-26 15:09:59 +02:00
models Add bloom flax (#25094) 2023-07-27 18:24:56 +01:00
optimization Make schedulers picklable by making lr_lambda fns global (#21768) 2023-03-02 12:08:43 -05:00
pipelines [Llama2] Add support for Llama 2 (#24891) 2023-07-18 15:18:31 -04:00
repo_utils Fix expected value in tests of the test fetcher (#24077) 2023-06-07 11:38:56 -04:00
sagemaker Avoid invalid escape sequences, use raw strings (#22936) 2023-04-25 09:17:56 -04:00
tokenization [ PreTrainedTokenizerFast] Keep properties from fast tokenizer (#25053) 2023-07-25 18:45:01 +02:00
tools Add support for for loops in python interpreter (#24429) 2023-06-26 09:58:14 -04:00
trainer Add dispatch_batches to training arguments (#25038) 2023-07-24 09:27:19 -04:00
utils Make (TF) CI faster (test only a subset of model classes) (#24592) 2023-06-30 16:54:54 +02:00
__init__.py GPU text generation: Moved the encoded_prompt to correct device 2020-01-06 15:11:12 +01:00
test_backbone_common.py Add TimmBackbone model (#22619) 2023-06-06 17:11:30 +01:00
test_configuration_common.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_configuration_utils.py update use_auth_token -> token (#25083) 2023-07-26 15:09:59 +02:00
test_feature_extraction_common.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_feature_extraction_utils.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_image_processing_common.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_image_processing_utils.py Run hub tests (#24807) 2023-07-13 15:25:45 -04:00
test_image_transforms.py Bug fix - flip_channel_order for channels first images (#23701) 2023-05-31 17:12:27 +01:00
test_modeling_common.py Edit err message and comment in test_model_is_small (#25087) 2023-07-25 12:24:36 -04:00
test_modeling_flax_common.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_modeling_flax_utils.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_modeling_tf_common.py Speed up TF tests by reducing hidden layer counts (#24595) 2023-06-30 16:30:33 +01:00
test_modeling_tf_utils.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00
test_modeling_utils.py Show a warning for missing attention masks when pad_token_id is not None (#24510) 2023-06-30 08:19:39 -04:00
test_pipeline_mixin.py Update tiny models for pipeline testing. (#24364) 2023-06-20 14:43:10 +02:00
test_sequence_feature_extraction_common.py Apply ruff flake8-comprehensions (#21694) 2023-02-22 09:14:54 +01:00
test_tokenization_common.py Fix TypeError: Object of type int64 is not JSON serializable (#24340) 2023-06-27 12:15:49 +01:00
test_tokenization_utils.py Split common test from core tests (#24284) 2023-06-15 07:30:24 -04:00