Commit Graph

3 Commits

Author SHA1 Message Date
Sylvain Gugger
27d4639779
Make gradient_checkpointing a training argument (#13657)
* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
2021-09-22 07:51:38 -04:00
Patrick von Platen
2bef3433e5
[Flax] Correct all return tensors to numpy (#13307)
* fix_torch_device_generate_test

* remove @

* finish find and replace
2021-08-27 17:38:34 +02:00
Suraj Patil
7a259c190c
FlaxGPTNeo (#12493)
* flax gpt neo

* fix query scaling

* update generation test

* use flax model for test
2021-07-06 18:55:18 +05:30