From 9ade8e749931781cb4c356b97575b983906ba76a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 15 Oct 2020 11:26:08 +0200
Subject: [PATCH] Upgrading TFAutoModelWithLMHead to (#7730)

- TFAutoModelForCausalLM
- TFAutoModelForMaskedLM
- TFAutoModelForSeq2SeqLM

as per deprecation warning. No tests as it simply removes current warnings
from tests.
---
 src/transformers/pipelines.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index 33e7efaea65..2d3d5830532 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -52,11 +52,11 @@ if is_tf_available():
         TF_MODEL_WITH_LM_HEAD_MAPPING,
         TFAutoModel,
         TFAutoModelForCausalLM,
+        TFAutoModelForMaskedLM,
         TFAutoModelForQuestionAnswering,
         TFAutoModelForSeq2SeqLM,
         TFAutoModelForSequenceClassification,
         TFAutoModelForTokenClassification,
-        TFAutoModelWithLMHead,
     )

 if is_torch_available():
@@ -2577,31 +2577,31 @@ SUPPORTED_TASKS = {
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
-        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "tf": TFAutoModelForMaskedLM if is_tf_available() else None,
         "pt": AutoModelForMaskedLM if is_torch_available() else None,
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
     },
     "summarization": {
         "impl": SummarizationPipeline,
-        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
         "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
     },
     "translation_en_to_fr": {
         "impl": TranslationPipeline,
-        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
         "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_de": {
         "impl": TranslationPipeline,
-        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
         "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_ro": {
         "impl": TranslationPipeline,
-        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
         "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
@@ -2613,7 +2613,7 @@ SUPPORTED_TASKS = {
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
-        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
     },