Commit Graph

4 Commits

Author / SHA1 / Message / Date
Yih-Dar
4cdb67caba
Use cross_attention_hidden_size in Encoder-Decoder models (#14378)
* add cross_attention_hidden_size to text-2-text encoder-decoder models (PT/Flax)

* for TFEncoderDecoderModel

* add equivalence test for TFEncoderDecoderModel

* fix

* fix failed equivalence tests

* remove unused import

* add detailed comment

* Fix check_equivalence_tf_to_pt by using encoder/decoder

* cleaning

* Use cross_attention_hidden_size in speech-to-text

* clean fast init logging msg in encoder decoder models

* increase tol from 1e-5 to 1e-3 for tf test

* style

* style

* make sure projection layer can run

* remove type conversion + add check

* fix conflict (config.output_hidden_size)

* Remove TF -> PT in check_pt_tf_equivalence for TFEncoderDecoderModel

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2021-12-07 00:27:32 +01:00
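
The PR above (#14378) threads cross_attention_hidden_size through the PT/TF/Flax encoder-decoder classes so that encoder outputs can be projected to the size the decoder's cross-attention expects. A minimal PyTorch sketch of the idea (illustrative only, not the transformers implementation; the helper name is made up):

```python
import torch
from torch import nn

def make_enc_to_dec_proj(encoder_hidden_size,
                         decoder_hidden_size,
                         cross_attention_hidden_size=None):
    """Pick the encoder->decoder projection, mirroring the idea in #14378."""
    if cross_attention_hidden_size is not None:
        # The decoder declares the hidden size its cross-attention expects:
        # project encoder states to that size.
        return nn.Linear(encoder_hidden_size, cross_attention_hidden_size)
    if encoder_hidden_size != decoder_hidden_size:
        # No explicit cross-attention size: only project when the two
        # hidden sizes differ.
        return nn.Linear(encoder_hidden_size, decoder_hidden_size)
    return None  # sizes already match, no projection layer needed

proj = make_enc_to_dec_proj(512, 768, cross_attention_hidden_size=256)
encoder_states = torch.randn(2, 10, 512)   # (batch, seq_len, encoder hidden)
print(proj(encoder_states).shape)          # torch.Size([2, 10, 256])
```

The projection only changes the feature dimension of the states handed to the decoder's cross-attention; the encoder's own outputs are left untouched.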
Yih-Dar
95b3ec3bc9
Add FlaxVisionEncoderDecoderModel (#13359)
* Start the work on FlaxVisionEncoderDecoderModel

* Add FlaxVisionEncoderDecoderModel

* Add VisionEncoderDecoderConfig

* Make FlaxVisionEncoderDecoderModel visible to transformers

* Add test

* Fix wrong getattr usage

* Fix tests

* Add FlaxAutoModelForVision2Seq

* Expose FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING

* clean-up

* add integration test

* update expected logits

* update expected scores

* Add ViT2GPT2ModelIntegrationTest + some cleaning

* Add projection layer + PT/Flax equivalence tests

* Fix import

* minor changes

* make test slow again

* Apply suggestions

* Add modeling_flax_vision_encoder_decoder to _ignore_modules in get_model_modules()

* fix copies

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* split long strings in multiple lines

* decoder_input_ids can't be None

* Add back test_configuration_tie

* Remove attention_mask parameter

* fix test - encoder_last_hidden_state should be encoder_outputs.last_hidden_state instead of the projected vector

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove more encoder_attention_mask

* remove encoder_attention_mask when calling self.decode (in FlaxVisionEncoderDecoderModule)

* Fix style + pass 1s instead of None as encoder_attention_mask

* fix init_weights

* pass None for encoder_attention_mask

* pass 1s instead of None as encoder_attention_mask

* Fix doc style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-11-09 15:14:28 +05:30
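
Since the commit above introduces FlaxVisionEncoderDecoderModel and FlaxAutoModelForVision2Seq, here is a hedged usage sketch of warm-starting a ViT encoder with a GPT-2 decoder. The checkpoint names, feature-extractor class, image path, and token-id settings are assumptions following the library's documented pattern, not taken from the commit:

```python
from PIL import Image
from transformers import (
    FlaxVisionEncoderDecoderModel,
    GPT2Tokenizer,
    ViTFeatureExtractor,
)

# Pair a pretrained ViT encoder with a pretrained GPT-2 decoder; the decoder's
# cross-attention (and the optional projection layer) is freshly initialized.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

image = Image.open("cat.png")  # hypothetical local RGB image
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# GPT-2 has no pad token; reuse eos, and start decoding from bos (assumption).
model.config.eos_token_id = model.config.decoder.eos_token_id
model.config.pad_token_id = model.config.eos_token_id
model.config.decoder_start_token_id = model.config.decoder.bos_token_id

sequences = model.generate(pixel_values, max_length=16).sequences
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))
```

Without fine-tuning, the generated captions are of course meaningless; the point is only the wiring from pixel inputs through the vision encoder into the text decoder.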
Yih-Dar
4a394cf53f
Fix test_configuration_tie in FlaxEncoderDecoderModelTest (#14076)
* check test_configuration_tie

* Fix test_configuration_tie

* make test slow again

* Remove property and use model.module.bind

* revert to slow test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2021-11-02 15:32:41 +05:30
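
For reference, the "model.module.bind" approach mentioned above relies on Flax's Module.bind, which returns a parameter-bound copy of a module so that its submodules can be inspected and called directly (handy for checking that encoder and decoder stay tied to the shared config). A standalone Flax sketch of the pattern, using toy modules rather than the transformers test:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Encoder(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

class EncoderDecoder(nn.Module):
    def setup(self):
        self.encoder = Encoder()
        self.decoder = nn.Dense(2)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

x = jnp.ones((1, 3))
model = EncoderDecoder()
variables = model.init(jax.random.PRNGKey(0), x)

# bind() attaches the variables and returns an interactive copy of the module,
# so submodules defined in setup() can be accessed and called without apply().
bound = model.bind(variables)
print(bound.encoder(x).shape)  # (1, 4)
```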
Yih-Dar
2e20c0f34a
Make Flax GPT2 work with cross attention (#13008)
* make Flax GPT2 work with cross attention

* Remove encoder->decoder projection layer

* A draft (incomplete) for FlaxEncoderDecoderModel

* Add the method from_encoder_decoder_pretrained + the docstrings

* Fix the mistakes of using EncoderDecoderModel

* Fix style

* Add FlaxEncoderDecoderModel to the library

* Fix cyclic imports

* Add FlaxEncoderDecoderModel to modeling_flax_auto.py

* Remove question comments

* add tests for FlaxEncoderDecoderModel

* add flax_encoder_decoder to the lists of ignored entries in check_repo.py

* fix missing required positional arguments

* Remove **kwargs when creating FlaxEncoderDecoderModel in from_encoder_decoder_pretrained()

Also fix generation eos/pad tokens issue

* Fix: Use sequences from the generated_output

* Change a check from assert to raise ValueError

* Fix examples and token ids issues

* Fix missing all_cross_attentions when outputting tuple in modeling_gpt2

* Remove the changes in configuration docstrings.

* allow for bert 2 gpt2

* make fix-copies

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Change remaining examples to bert2gpt2

* Change the test to Bert2GPT2

* Fix examples

* Fix import

* Fix unpack bug

* Rename to FlaxEncoderDecoderModelTest and change the test to bert2gpt2

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix: NotImplentedError -> NotImplementedError

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* up

* finalize

Co-authored-by: ydshieh <ydshieh@user.noreply>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-08-23 17:57:29 +02:00
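
The PR above adds cross-attention to Flax GPT-2 and a first FlaxEncoderDecoderModel, with the examples switched to bert2gpt2. A hedged sketch of that warm-starting pattern (checkpoints and token-id settings are assumptions mirroring the library docs; because the cross-attention weights are freshly initialized, the untrained model's output is not meaningful):

```python
from transformers import BertTokenizer, FlaxEncoderDecoderModel, GPT2Tokenizer

# Warm-start a bert2gpt2 model; the Flax GPT-2 cross-attention added in this
# PR is what lets GPT-2 act as the decoder on the Flax side.
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-cased", "gpt2"
)
encoder_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
decoder_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

inputs = encoder_tokenizer(
    "The Eiffel Tower is located in Paris.", return_tensors="np"
)

# GPT-2 has no pad token, and generation needs a decoder start token
# (assumption: reuse GPT-2's bos/eos ids, as in the library examples).
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.pad_token_id = decoder_tokenizer.eos_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id

sequences = model.generate(inputs.input_ids, max_length=16).sequences
print(decoder_tokenizer.batch_decode(sequences, skip_special_tokens=True))
```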