diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index 10c682d0a96..16589469256 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -60,13 +60,14 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         AutoModelForQuestionAnswering,
         AutoModelForTokenClassification,
-        AutoModelWithLMHead,
         AutoModelForSeq2SeqLM,
-        MODEL_WITH_LM_HEAD_MAPPING,
+        AutoModelForCausalLM,
+        AutoModelForMaskedLM,
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+        MODEL_FOR_MASKED_LM_MAPPING,
     )
 
 if TYPE_CHECKING:
@@ -1029,7 +1030,7 @@ class FillMaskPipeline(Pipeline):
             task=task,
         )
 
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
 
         self.topk = topk
 
@@ -1817,7 +1818,9 @@ class TranslationPipeline(Pipeline):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+        self.check_model_type(
+            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        )
 
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
@@ -1933,7 +1936,7 @@ SUPPORTED_TASKS = {
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForMaskedLM if is_torch_available() else None,
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
     },
     "summarization": {
@@ -1945,25 +1948,25 @@ SUPPORTED_TASKS = {
     "translation_en_to_fr": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_de": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_ro": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForCausalLM if is_torch_available() else None,
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
     },
     "zero-shot-classification": {
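
A minimal sketch of what the remapping above means in practice, assuming a PyTorch install of `transformers` at this revision. The task names and default checkpoints come straight from the `SUPPORTED_TASKS` table in the diff; on the PyTorch side each task now resolves through its task-specific auto class instead of the catch-all `AutoModelWithLMHead`:

```python
from transformers import pipeline

# "fill-mask" now loads its PyTorch model through AutoModelForMaskedLM
fill_mask = pipeline("fill-mask", model="distilroberta-base")

# "text-generation" now loads through AutoModelForCausalLM
generator = pipeline("text-generation", model="gpt2")

# the translation_* tasks now load through AutoModelForSeq2SeqLM
translator = pipeline("translation_en_to_de", model="t5-base")
```

A side effect of the `check_model_type` changes is that `FillMaskPipeline` and `TranslationPipeline` now validate against the narrower `MODEL_FOR_MASKED_LM_MAPPING` and `MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING` tables rather than the broad `MODEL_WITH_LM_HEAD_MAPPING`, so a model of the wrong head type is flagged up front instead of passing the check.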