Commit Graph

4 Commits

Author SHA1 Message Date
Suraj Patil
eb881674f2
[Flax] [WIP] allow loading head model with base model weights (#12255)
* boom boom

* remove flax clip example

* allow loading head model with base model weights

* add test

* fix imports

* disable save, load test for clip

* add test_save_load_to_base
2021-06-21 15:56:42 +01:00
Suraj Patil
8d5b7f36e5
[FlaxClip] fix test from/save pretrained test (#12284)
* boom boom

* remove flax clip example

* fix from_save_pretrained
2021-06-21 15:54:34 +01:00
Vasudev Gupta
d9c0d08f9a
Flax Big Bird (#11967)
* add flax bert

* bert -> bigbird

* original_full ported

* add debugger

* init block sparse

* fix copies ; gelu_fast -> gelu_new

* block sparse port

* fix block sparse

* block sparse working

* all ckpts working

* fix-copies

* make quality

* init tests

* temporary fix for FlaxBigBirdForMultipleChoice

* skip test_attention_outputs

* fix

* gelu_fast -> gelu_new ; fix multiple choice model

* remove nsp

* fix sequence classifier

* fix

* make quality

* make fix-copies

* finish

* Delete debugger.ipynb

* Update src/transformers/models/big_bird/modeling_flax_big_bird.py

* make style

* finish

* bye bye jit flax tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-06-14 20:01:03 +01:00
Suraj Patil
ad25fd62bd
Add FlaxCLIP (#11883)
* add flax CLIP

* default input_shape

* add tests

* fix test

* fix name

* fix docs

* fix shapes

* attend at least 1 token

* flax conv to torch conv

* return floats

* fix equivalence tests

* fix import

* return attention_weights and update tests

* fix dosctrings

* address patricks comments

* input_shape arg

* add tests for get_image_features and get_text_features methods

* fix tests
2021-06-01 09:44:31 +05:30