Commit Graph

8 Commits

Author SHA1 Message Date
Suraj Patil
3d607df8f4
fix loading flax bf16 weights in pt (#14369)
* fix loading flax bf16 weights in pt

* fix clip test

* fix t5 test

* add logging statement

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* switch back to native any

* fix check for bf16 weights

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-11-11 21:20:49 +05:30
Patrick von Platen
7c6cd0ac28
up (#14046) 2021-10-18 12:59:18 +02:00
Patrick von Platen
6900dded49
[Flax/JAX] Run jitted tests at every commit (#13090)
* up

* up

* up
2021-08-12 14:49:46 +02:00
Patrick von Platen
60e448c87e
[Flax] Correct pt to flax conversion if from base to head (#13006)
* finish PR

* add tests

* correct tests

* finish

* correct other flax tests

* better naming

* correct naming

* finish

* apply sylvains suggestions
2021-08-05 18:38:50 +02:00
Suraj Patil
eb881674f2
[Flax] [WIP] allow loading head model with base model weights (#12255)
* boom boom

* remove flax clip example

* allow loading head model with base model weights

* add test

* fix imports

* disable save, load test for clip

* add test_save_load_to_base
2021-06-21 15:56:42 +01:00
Suraj Patil
8d5b7f36e5
[FlaxClip] fix test from/save pretrained test (#12284)
* boom boom

* remove flax clip example

* fix from_save_pretrained
2021-06-21 15:54:34 +01:00
Vasudev Gupta
d9c0d08f9a
Flax Big Bird (#11967)
* add flax bert

* bert -> bigbird

* original_full ported

* add debugger

* init block sparse

* fix copies ; gelu_fast -> gelu_new

* block sparse port

* fix block sparse

* block sparse working

* all ckpts working

* fix-copies

* make quality

* init tests

* temporary fix for FlaxBigBirdForMultipleChoice

* skip test_attention_outputs

* fix

* gelu_fast -> gelu_new ; fix multiple choice model

* remove nsp

* fix sequence classifier

* fix

* make quality

* make fix-copies

* finish

* Delete debugger.ipynb

* Update src/transformers/models/big_bird/modeling_flax_big_bird.py

* make style

* finish

* bye bye jit flax tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-06-14 20:01:03 +01:00
Suraj Patil
ad25fd62bd
Add FlaxCLIP (#11883)
* add flax CLIP

* default input_shape

* add tests

* fix test

* fix name

* fix docs

* fix shapes

* attend at least 1 token

* flax conv to torch conv

* return floats

* fix equivalence tests

* fix import

* return attention_weights and update tests

* fix dosctrings

* address patricks comments

* input_shape arg

* add tests for get_image_features and get_text_features methods

* fix tests
2021-06-01 09:44:31 +05:30