Add new AutoModel classes in pipeline (#6062)

* use new AutoModel classed

* make style and quality
This commit is contained in:
Suraj Patil 2020-07-27 21:20:08 +05:30 committed by GitHub
parent 5779e5434d
commit c8bdf7f4ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -60,13 +60,14 @@ if is_torch_available():
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForTokenClassification, AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
MODEL_WITH_LM_HEAD_MAPPING, AutoModelForCausalLM,
AutoModelForMaskedLM,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1029,7 +1030,7 @@ class FillMaskPipeline(Pipeline):
task=task, task=task,
) )
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING) self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
self.topk = topk self.topk = topk
@ -1817,7 +1818,9 @@ class TranslationPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING) self.check_model_type(
TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
)
def __call__( def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
@ -1933,7 +1936,7 @@ SUPPORTED_TASKS = {
"fill-mask": { "fill-mask": {
"impl": FillMaskPipeline, "impl": FillMaskPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None, "tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None, "pt": AutoModelForMaskedLM if is_torch_available() else None,
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}}, "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
}, },
"summarization": { "summarization": {
@ -1945,25 +1948,25 @@ SUPPORTED_TASKS = {
"translation_en_to_fr": { "translation_en_to_fr": {
"impl": TranslationPipeline, "impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None, "tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None, "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
}, },
"translation_en_to_de": { "translation_en_to_de": {
"impl": TranslationPipeline, "impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None, "tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None, "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
}, },
"translation_en_to_ro": { "translation_en_to_ro": {
"impl": TranslationPipeline, "impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None, "tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None, "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
}, },
"text-generation": { "text-generation": {
"impl": TextGenerationPipeline, "impl": TextGenerationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None, "tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None, "pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}}, "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
}, },
"zero-shot-classification": { "zero-shot-classification": {