Sanchit Gandhi
|
e93103632b
|
Add Flax BLOOM (#25094)
* First commit
* step 1 working
* add alibi
* placeholder for `scan`
* add matrix-multiplication alibi
* add beta scaling factor for bmm
* working v1 - simple forward pass
* move layer_number from attribute to arg in call
* partially functioning scan
* hacky but working scan
* add more modifications
* add test
* update scan for new kwarg order
* fix position_ids problem
* fix bug in attention layer
* small fix
- do the alibi broadcasting only once
* prelim refactor
* finish refactor
* alibi shifting
* incorporate dropout_add into the attention module
* make style
* make padding work again
* update
* remove bogus file
* up
* get generation to work
* clean code a bit
* added small tests
* add alibi test
* make CI tests pass:
- change weight init
- return the correct tuple for output attentions
- add scan test
- make CI tests work
* fix a few nits
* fix nit onnx
* fix onnx nit
* add missing dtype args to nn.Modules
* remove debugging statements
* fix scan generate
* Update modeling_flax_bloom.py
* Update test_modeling_flax_bloom.py
* Update test_modeling_flax_bloom.py
* Update test_modeling_flax_bloom.py
* fix small test issue + make style
* clean up
* Update tests/models/bloom/test_modeling_flax_bloom.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* fix function name
* small test fix
* forward contrib credits from PR17761
* Fix failing test
* fix small typo in documentation
* fix non-passing test
- remove device from build alibi
* refactor call
- refactor `FlaxBloomBlockCollection` module
* make style
* upcast to fp32
* cleaner way to upcast
* remove unused args
* remove layer number
* fix scan test
* make style
* fix i4 casting
* fix slow test
* Update src/transformers/models/bloom/modeling_flax_bloom.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* remove `layer_past`
* refactor a bit
* fix `scan` slow test
* remove useless import
* major changes
- remove unused code
- refactor a bit
- revert import `torch`
* major refactoring
- change build alibi
* remove scan
* fix tests
* make style
* clean-up alibi
* add integration tests
* up
* fix batch norm conversion
* style
* style
* update pt-fx cross tests
* update copyright
* Update src/transformers/modeling_flax_pytorch_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* per-weight check
* style
* fix line formatting
---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: haileyschoelkopf <haileyschoelkopf@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
|
2023-07-27 18:24:56 +01:00