Upgrading TFAutoModelWithLMHead to (#7730)

- TFAutoModelForCausalLM
- TFAutoModelForMaskedLM
- TFAutoModelForSeq2SeqLM

as per deprecation warning. No tests as it simply removes current
warnings from tests.
This commit is contained in:
Nicolas Patry 2020-10-15 11:26:08 +02:00 committed by GitHub
parent 62b5622e6b
commit 9ade8e7499
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,11 +52,11 @@ if is_tf_available():
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
if is_torch_available():
@ -2577,31 +2577,31 @@ SUPPORTED_TASKS = {
},
"fill-mask": {
"impl": FillMaskPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"tf": TFAutoModelForMaskedLM if is_tf_available() else None,
"pt": AutoModelForMaskedLM if is_torch_available() else None,
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
},
"summarization": {
"impl": SummarizationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
},
"translation_en_to_fr": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
"translation_en_to_de": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
"translation_en_to_ro": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
@ -2613,7 +2613,7 @@ SUPPORTED_TASKS = {
},
"text-generation": {
"impl": TextGenerationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
"pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
},