Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Add new AutoModel classes in pipeline (#6062)
* use new AutoModel classes
* make style and quality
commit c8bdf7f4ec
parent 5779e5434d
@@ -60,13 +60,14 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         AutoModelForQuestionAnswering,
         AutoModelForTokenClassification,
-        AutoModelWithLMHead,
         AutoModelForSeq2SeqLM,
-        MODEL_WITH_LM_HEAD_MAPPING,
+        AutoModelForCausalLM,
+        AutoModelForMaskedLM,
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+        MODEL_FOR_MASKED_LM_MAPPING,
     )
 
 if TYPE_CHECKING:
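Note: a minimal sketch (not part of this commit) of how the task-specific classes replace the deprecated catch-all AutoModelWithLMHead; the checkpoint names are just the pipeline defaults from the hunks below:

from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM

# Each Auto class resolves a checkpoint to the head matching its task,
# instead of guessing a head like the old AutoModelWithLMHead.
masked_lm = AutoModelForMaskedLM.from_pretrained("distilroberta-base")  # fill-mask
causal_lm = AutoModelForCausalLM.from_pretrained("gpt2")                # text-generation
seq2seq_lm = AutoModelForSeq2SeqLM.from_pretrained("t5-base")           # translation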
@@ -1029,7 +1030,7 @@ class FillMaskPipeline(Pipeline):
             task=task,
         )
 
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
 
         self.topk = topk
 
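Note: a rough paraphrase (assumed, not the library's exact code) of what the tightened check in FillMaskPipeline amounts to: the model's class must appear among the values of the task-specific mapping, so only masked-LM heads pass for fill-mask:

from transformers import AutoModelForMaskedLM
from transformers.modeling_auto import MODEL_FOR_MASKED_LM_MAPPING

model = AutoModelForMaskedLM.from_pretrained("distilroberta-base")
# check_model_type effectively asserts membership in the mapping's values;
# a causal-LM-only model would now fail this check for fill-mask.
assert model.__class__ in MODEL_FOR_MASKED_LM_MAPPING.values()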
@@ -1817,7 +1818,9 @@ class TranslationPipeline(Pipeline):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+        self.check_model_type(
+            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        )
 
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
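Note: with this change the PyTorch translation pipelines load AutoModelForSeq2SeqLM under the hood; a small usage sketch, assuming the t5-base default from the table below:

from transformers import pipeline

translator = pipeline("translation_en_to_fr", model="t5-base")
print(translator("The pipeline model classes are now task specific."))
# -> [{'translation_text': '...'}]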
@@ -1933,7 +1936,7 @@ SUPPORTED_TASKS = {
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForMaskedLM if is_torch_available() else None,
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
     },
     "summarization": {
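Note: a quick usage sketch of the updated fill-mask default (distilroberta-base uses the <mask> token):

from transformers import pipeline

fill_mask = pipeline("fill-mask", model="distilroberta-base")
print(fill_mask("Paris is the <mask> of France."))
# -> list of top-k candidate fills with scores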
@@ -1945,25 +1948,25 @@ SUPPORTED_TASKS = {
     "translation_en_to_fr": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_de": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_ro": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForCausalLM if is_torch_available() else None,
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
     },
     "zero-shot-classification": {
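Note: text-generation now resolves to AutoModelForCausalLM on the PyTorch side; a short sketch with the gpt2 default (keyword arguments such as max_length are forwarded through **generate_kwargs):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
print(generator("The new AutoModel classes", max_length=20))
# -> [{'generated_text': '...'}]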