* Let's try autodetecting serving sigs
* Don't clobber existing sigs
* Change shapes for multiplechoice models
* Make default dummy inputs smarter too
* Fix missing f-string
* Let's YOLO a serving output too
* Read __class__.__name__ properly
* Don't just pass naked lists in there and expect it to be okay
* Code cleanup
* Update default serving sig
* Clearer error messages
* Further updates to the default serving output
* make fixup
* Update the serving output a bit more
* Cleanups and renames, raise errors appropriately when we can't infer inputs
* More renames
* we're building in a functional context again, yolo
* import DUMMY_INPUTS from the right place
* import DUMMY_INPUTS from the right place
* Support cross-attention in the dummies
* Support cross-attention in the dummies
* Complete removal of dummy/serving overrides in BERT
* Complete removal of dummy/serving overrides in RoBERTa
* Obliterate lots and lots of serving sig and dummy overrides
* merge type hint changes
* Fix for token_type_ids with vocab_size 1
* Add missing property decorator
* Fix T5 and hopefully some models that take conv inputs
* More signature pruning
* Fix T5's signature
* Fix Wav2Vec2 signature
* Fix LongformerForMultipleChoice input signature
* Fix BLIP and LED
* Better default serving output error handling
* Fix BART dummies
* Fix dummies for cross-attention, esp encoder-decoder models
* Fix visionencoderdecoder signature
* Fix BLIP serving output
* Small tweak to BART dummies
* Cleanup the ugly parameter inspection line that I used in a few places
* committed a breakpoint again
* Move the text_dims check
* Remove blip_text serving_output
* Add decoder_input_ids to the default input sig
* Remove all the manual overrides for encoder-decoder model signatures
* Tweak longformer/led input sigs
* Tweak default serving output
* output.keys() -> output
* make fixup
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor
* Don't forget the imports
* Add the imports to tests too
* make fixup
* Refactor tests that depended on get_type_hints
* Better test refactor
* Fix an old hidden bug in the test_keras_fit input creation code
* Fix for the Deit tests
* fix past renamed to past_key_value
* update more `past`that were ski^êd
* fixup
* remove changes made to rag
* refactor `_reorder_cache` to use `past_key_values`
* fix git `prepare_inputs_for_generation` to pass tests when false is needed in use_cache
* Add a test to ensure int dummy inputs are int64
* Move the test into the existing int64 test and update a lot of existing dummies
* Fix remaining dummies
* Fix remaining dummies
* Test for int64 serving sigs as well
* Update core tests to use tf.int64
* Add better messages to the assertions
* Update all serving sigs to int64
* More sneaky hiding tf.int32s
* Add an optional int32 signature in save_pretrained
* make fixup
* Add Amy's suggestions
* Switch all serving sigs back to tf.int32
* Switch all dummies to tf.int32
* Adjust tests to check for tf.int32 instead of tf.int64
* Fix base dummy_inputs dtype
* Start casting to tf.int32 in input_processing
* Change dtype for unpack_inputs test
* Add proper tf.int32 test
* Make the alternate serving signature int64
* move generation_*.py src files into generation/*.py
* populate generation.__init__ with lazy loading
* move imports and references from generation.xxx.object to generation.object
* added test
* correct embedding init
* some changes in blenderbot (incomplete)
* update blenderbot (diff to be used as reference)
* update blenderbot_small
* update LED
* update marian
* update T5 and remove TFWrappedEmbeddings
* nullcontext() -> ContextManagers()
* fix embedding init
* Override save() to use the serving signature as the default
* Replace int32 with int64 in all our serving signatures
* Remember one very important line so as not to break every test at once
* Dtype fix for TFLED
* dtype fix for shift_tokens_right in general
* Dtype fixes in mBART and RAG
* Fix dtypes for test_unpack_inputs
* More dtype fixes
* Yet more mBART + RAG dtype fixes
* Yet more mBART + RAG dtype fixes
* Add a check that the model actually has a serving method
Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu
version.parse(version.parse(torch.__version__).base_version) are preferred (and available in pytorch_utils.py
* [Flax] Add remat (gradient checkpointing)
* fix variable naming in test
* flip: checkpoint using a method
* fix naming
* fix class naming
* apply PVP's suggestions from code review
* make fix-copies
* fix big-bird, electra, roberta
* cookie-cutter
* fix flax big-bird
* move test to common
* Use torch.finfo(self.dtype).min
* for GPTNeoX
* for Albert
* For Splinter
* Update src/transformers/models/data2vec/modeling_data2vec_audio.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix -inf used in Bart-like models
* Fix a few remaining -inf
* more fix
* clean up
* For CLIP
* For FSMT
* clean up
* fix test
* Add dtype argument and use it for LayoutLMv3
* update FlaxLongT5Attention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* begin do_init
* add params_shape_tree
* raise error if params are accessed when do_init is False
* don't allow do_init=False when keys are missing
* make shape tree a property
* assign self._params at the end
* add test for do_init
* add do_init arg to all flax models
* fix param setting
* disbale do_init for composite models
* update test
* add do_init in FlaxBigBirdForMultipleChoice
* better names and errors
* improve test
* style
* add a warning when do_init=False
* remove extra if
* set params after _required_params
* add test for from_pretrained
* do_init => _do_init
* chage warning to info
* fix typo
* add params in init_weights
* add params to gpt neo init
* add params to init_weights
* update do_init test
* Trigger CI
* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* update template
* trigger CI
* style
* style
* fix template
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Adding new train_step logic to make things less confusing for users
* DO NOT ASK WHY WE NEED THAT SUBCLASS
* Metrics now working, at least for single-output models with type annotations!
* Updates and TODOs for the new train_step
* Make fixup
* Temporary test workaround until T5 has types
* Temporary test workaround until T5 has types
* I think this actually works! Needs a lot of tests though
* MAke style/quality
* Revert changes to T5 tests
* Deleting the aforementioned unmentionable subclass
* Deleting the aforementioned unmentionable subclass
* Adding a Keras API test
* Style fixes
* Removing unneeded TODO and comments
* Update test_step too
* Stop trying to compute metrics with the dummy_loss, patch up test
* Make style
* make fixup
* Docstring cleanup
* make fixup
* make fixup
* Stop expanding 1D input tensors when using dummy loss
* Adjust T5 test given the new compile()
* make fixup
* Skipping test for convnext
* Removing old T5-specific Keras test now that we have a common one
* make fixup
* make fixup
* Only skip convnext test on CPU
* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Avoiding TF import issues
* make fixup
* Update compile() to support TF 2.3
* Skipping model.fit() on template classes for now
* Skipping model.fit() on template class tests for now
* Replace ad-hoc solution with find_labels
* make fixup
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>