Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-12 09:10:05 +06:00
91ab02af28 (1 commit)
Commit 75627148ee by Morgan Funtowicz
Flax Masked Language Modeling training example (#8728)
* Remove "Model" suffix from Flax models to look more 🤗
* Initial working (forward + backward) for Flax MLM training example.
* Simplify code
* Addressing comments, using module and moving to LM task.
* Restore parameter name "module" wrongly renamed to model.
* Restore correct output ordering...
* Actually commit the example 😅
* Add FlaxBertModelForMaskedLM after rebasing.
* Make it possible to initialize the training from scratch
* Reuse flax linen example of cross entropy loss
* Added specific data collator for flax
* Remove todo for data collator
* Added evaluation step
* Added ability to provide dtype to support bfloat16 on TPU
* Enable flax tensorboard output
* Enable jax.pmap support.
* Ensure batches are correctly sized to be dispatched with jax.pmap
* Enable bfloat16 with --fp16 cmdline args
* Correctly export metrics to tensorboard
* Added dropout and the ability to use it.
* Effectively enable & disable dropout during training and evaluation steps.
* Oops.
* Enable specifying kernel initializer scale
* Style.
* Added warmup step to the learning rate scheduler.
* Fix typo.
* Print training loss
* Make style
* Fix linter issue (flake8)
* Fix model matching
* Fix dummies
* Fix non-default dtype on Flax models
* Use the same create_position_ids_from_input_ids for FlaxRoberta
* Make Roberta attention match Bert
* Fix copy
* Wording.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
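The commit reuses a cross-entropy loss for masked language modeling, where the loss is computed only over the positions the collator actually masked. A minimal NumPy sketch of that idea (not the example's actual code; the function name and shapes are illustrative assumptions):

```python
import numpy as np

def masked_lm_loss(logits, labels, mask):
    """Cross-entropy averaged over masked positions only (illustrative sketch).

    logits: (seq_len, vocab) unnormalized scores
    labels: (seq_len,) target token ids
    mask:   (seq_len,) 1.0 where the token was masked by the collator, else 0.0
    """
    # Numerically stable log-softmax: subtract the row max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-likelihood of each target token.
    nll = -log_probs[np.arange(labels.shape[0]), labels]
    # Average only over masked positions; unmasked tokens contribute nothing.
    return (nll * mask).sum() / mask.sum()
```

As a sanity check, uniform logits over a vocabulary of size V give a loss of log(V) regardless of which positions are masked.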
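One commit ensures batches are correctly sized before being dispatched with jax.pmap, which expects a leading axis equal to the device count. A hedged sketch of that reshaping step in plain NumPy (the helper name is an assumption, not the example's actual code):

```python
import numpy as np

def shard_batch(batch, n_devices):
    """Reshape (B, ...) -> (n_devices, B // n_devices, ...) for per-device dispatch.

    Trailing examples that don't divide evenly across devices are dropped,
    so every device receives an equally sized slice, as jax.pmap requires.
    """
    per_device = batch.shape[0] // n_devices
    batch = batch[: per_device * n_devices]
    return batch.reshape((n_devices, per_device) + batch.shape[1:])
```

For example, a batch of 10 rows sharded across 4 devices is trimmed to 8 rows and reshaped to (4, 2, ...).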