Mirror of https://github.com/huggingface/transformers.git

Commit: 61d9ee45e3
Parent: e516a34a15

Commit message: All tests are green.
@@ -343,8 +343,9 @@ class Pipeline(_ScikitCompat):
         if 'distilbert' not in model_type and 'xlm' not in model_type:
             args += ['token_type_ids']
 
-        if 'xlnet' in model_type or 'xlm' in model_type:
-            args += ['cls_index', 'p_mask']
+        # PR #1548 (CLI) There is an issue with attention_mask
+        # if 'xlnet' in model_type or 'xlm' in model_type:
+        #     args += ['cls_index', 'p_mask']
 
         if isinstance(features, dict):
             return {k: features[k] for k in args}
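Note: the hunk above deactivates the cls_index/p_mask arguments for XLNet/XLM by replacing the live branch with a commented-out copy, leaving only the token_type_ids filter active. As a hedged, self-contained sketch of what this argument filter does (the `args` seed list and the function wrapper are assumptions, not the exact source):

    # Minimal sketch of the input-filtering logic in the hunk above.
    def filter_model_inputs(features: dict, model_type: str) -> dict:
        args = ['input_ids', 'attention_mask']  # assumed baseline inputs

        # DistilBERT and XLM variants do not consume token_type_ids.
        if 'distilbert' not in model_type and 'xlm' not in model_type:
            args += ['token_type_ids']

        # cls_index / p_mask for XLNet/XLM stays disabled in this commit
        # (attention_mask issue referenced against PR #1548).

        if isinstance(features, dict):
            return {k: features[k] for k in args}
        return features

    # Example: a DistilBERT checkpoint keeps only input_ids / attention_mask.
    features = {'input_ids': [[101, 2054]], 'attention_mask': [[1, 1]], 'token_type_ids': [[0, 0]]}
    print(filter_model_inputs(features, 'distilbert'))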
@@ -380,7 +381,7 @@ class Pipeline(_ScikitCompat):
             predictions = self.model(inputs, training=False)[0]
         else:
             with torch.no_grad():
-                predictions = self.model(**inputs).cpu()[0]
+                predictions = self.model(**inputs)[0].cpu()
 
         return predictions.numpy()
 
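Note: the one-line PyTorch fix above swaps the order of `[0]` and `.cpu()`. The model call returns a tuple of tensors, and .cpu() is a tensor method, so the tuple must be indexed before the result is moved to host memory. A toy reproduction (the model class is a stand-in, not the transformers API):

    import torch

    class ToyModel(torch.nn.Module):
        # Mimics transformers models, whose forward returns a tuple of outputs.
        def forward(self, input_ids=None):
            return (input_ids.float() * 2.0,)

    model = ToyModel()
    inputs = {'input_ids': torch.tensor([[1, 2, 3]])}

    with torch.no_grad():
        # Old order would raise: a tuple has no .cpu() method.
        # predictions = model(**inputs).cpu()[0]
        predictions = model(**inputs)[0].cpu()  # fixed: index first, then move

    print(predictions.numpy())  # matches the .numpy() call in the pipeline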
@@ -444,7 +445,7 @@ class NerPipeline(Pipeline):
 
         # Forward
         if is_tf_available():
-            entities = self.model(**tokens)[0][0].numpy()
+            entities = self.model(tokens)[0][0].numpy()
         else:
             with torch.no_grad():
                 entities = self.model(**tokens)[0][0].cpu().numpy()
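Note: the NerPipeline change above follows the same split. The TF branch now passes the tokenized batch positionally (self.model(tokens)) because a Keras model takes a single `inputs` argument, which may be a dict of tensors, while a PyTorch module is called with keyword arguments. A schematic comparison with stand-in classes (not the real framework APIs):

    class KerasStyleModel:
        def __call__(self, inputs, training=False):
            # Keras style: one positional `inputs`, possibly a dict of tensors.
            return (inputs['input_ids'],)

    class TorchStyleModel:
        def __call__(self, input_ids=None, attention_mask=None):
            # PyTorch style: one keyword argument per tensor.
            return (input_ids,)

    tokens = {'input_ids': [[101, 102]], 'attention_mask': [[1, 1]]}
    KerasStyleModel()(tokens)    # TF branch: model(tokens)
    TorchStyleModel()(**tokens)  # PyTorch branch: model(**tokens)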
@@ -1,113 +1,164 @@
 import unittest
 from unittest.mock import patch
 
+from typing import Iterable
+
 from transformers import pipeline
+from transformers.tests.utils import require_tf, require_torch
 
 QA_FINETUNED_MODELS = {
-    'bert-large-uncased-whole-word-masking-finetuned-squad',
-    'bert-large-cased-whole-word-masking-finetuned-squad',
-    'distilbert-base-uncased-distilled-squad',
+    ('bert-base-uncased', 'bert-large-uncased-whole-word-masking-finetuned-squad', None),
+    ('bert-base-cased', 'bert-large-cased-whole-word-masking-finetuned-squad', None),
+    ('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None)
 }
+
+NER_FINETUNED_MODELS = {
+    (
+        'bert-base-cased',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json'
+    )
+}
+
+FEATURE_EXTRACT_FINETUNED_MODELS = {
+    ('bert-base-cased', 'bert-base-cased', None),
+    # ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2
+    ('distilbert-base-uncased', 'distilbert-base-uncased', None)
+}
+
+TEXT_CLASSIF_FINETUNED_MODELS = {
+    (
+        'bert-base-uncased',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json'
+    )
+}
 
 
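Note: each registry now holds (tokenizer, model, config) triples instead of bare model identifiers, so a test can point the tokenizer and the fine-tuned weights at different checkpoints (or at raw S3 URLs, as in NER_FINETUNED_MODELS) and pass config as None to derive it from the model. A minimal sketch of how the tests below unpack an entry (argument meanings inferred from the diff):

    for tokenizer, model, config in QA_FINETUNED_MODELS:
        # tokenizer: checkpoint providing the tokenizer files
        # model:     checkpoint or URL for the fine-tuned weights
        # config:    explicit config, or None to infer it from `model`
        nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)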
-class QuestionAnsweringPipelineTest(unittest.TestCase):
-    def check_answer_structure(self, answer, batch, topk):
-        self.assertIsInstance(answer, list)
-        self.assertEqual(len(answer), batch)
-        self.assertIsInstance(answer[0], list)
-        self.assertEqual(len(answer[0]), topk)
-        self.assertIsInstance(answer[0][0], dict)
-
-        for item in answer[0]:
-            self.assertTrue('start' in item)
-            self.assertTrue('end' in item)
-            self.assertTrue('score' in item)
-            self.assertTrue('answer' in item)
-
-    def question_answering_pipeline(self, nlp):
-        # Simple case with topk = 1, no batching
-        a = nlp(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.')
-        self.check_answer_structure(a, 1, 1)
-
-        # Simple case with topk = 2, no batching
-        a = nlp(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.', topk=2)
-        self.check_answer_structure(a, 1, 2)
-
-        # Batch case with topk = 1
-        a = nlp(question=['What is the name of the company I\'m working for ?', 'Where is the company based ?'],
-                context=['I\'m working for Huggingface.', 'The company is based in New York and Paris'])
-        self.check_answer_structure(a, 2, 1)
-
-        # Batch case with topk = 2
-        a = nlp(question=['What is the name of the company I\'m working for ?', 'Where is the company based ?'],
-                context=['Where is the company based ?', 'The company is based in New York and Paris'], topk=2)
-        self.check_answer_structure(a, 2, 2)
-
-        # check for data keyword
-        a = nlp(data=nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'))
-        self.check_answer_structure(a, 1, 1)
-
-        a = nlp(data=nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'), topk=2)
-        self.check_answer_structure(a, 1, 2)
-
-        a = nlp(data=[
-            nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'),
-            nlp.create_sample(question='I\'m working for Huggingface.', context='The company is based in New York and Paris'),
-        ])
-        self.check_answer_structure(a, 2, 1)
-
-        a = nlp(data=[
-            {'question': 'What is the name of the company I\'m working for ?', 'context': 'I\'m working for Huggingface.'},
-            {'question': 'Where is the company based ?', 'context': 'The company is based in New York and Paris'},
-        ])
-        self.check_answer_structure(a, 2, 1)
-
-        # X keywords
-        a = nlp(X=nlp.create_sample(
-            question='Where is the company based ?', context='The company is based in New York and Paris'
-        ))
-        self.check_answer_structure(a, 1, 1)
-
-        a = nlp(X=[
-            {'question': 'What is the name of the company I\'m working for ?', 'context': 'I\'m working for Huggingface.'},
-            {'question': 'Where is the company based ?', 'context': 'The company is based in New York and Paris'},
-        ], topk=2)
-        self.check_answer_structure(a, 2, 2)
-
-    @patch('transformers.pipelines.is_torch_available', return_value=False)
-    def test_tf_models(self, is_torch_available):
-        from transformers import pipeline
-        for model in QA_FINETUNED_MODELS:
-            self.question_answering_pipeline(pipeline('question-answering', model))
-
-    @patch('transformers.pipelines.is_tf_available', return_value=False)
-    @patch('transformers.tokenization_utils.is_tf_available', return_value=False)
-    def test_torch_models(self, is_tf_available, _):
-        from transformers import pipeline
-        for model in QA_FINETUNED_MODELS:
-            self.question_answering_pipeline(pipeline('question-answering', model))
+@require_tf
+def tf_pipeline(*args, **kwargs):
+    return pipeline(**kwargs)
 
 
-class AutoPipelineTest(unittest.TestCase):
-    @patch('transformers.pipelines.is_torch_available', return_value=False)
-    def test_tf_qa(self, is_torch_available):
-        from transformers import pipeline
-        from transformers.pipelines import QuestionAnsweringPipeline
-        from transformers.modeling_tf_utils import TFPreTrainedModel
-        for model in QA_FINETUNED_MODELS:
-            nlp = pipeline('question-answering', model)
-            self.assertIsInstance(nlp, QuestionAnsweringPipeline)
-            self.assertIsInstance(nlp.model, TFPreTrainedModel)
+@require_torch
+def torch_pipeline(*args, **kwargs):
+    return pipeline(**kwargs)
 
-    @patch('transformers.pipelines.is_tf_available', return_value=False)
-    def test_torch_qa(self, is_tf_available):
-        from transformers import pipeline
-        from transformers.pipelines import QuestionAnsweringPipeline
-        from transformers.modeling_utils import PreTrainedModel
-        for model in QA_FINETUNED_MODELS:
-            nlp = pipeline('question-answering', model)
-            self.assertIsInstance(nlp, QuestionAnsweringPipeline)
-            self.assertIsInstance(nlp.model, PreTrainedModel)
 
+class MonoColumnInputTestCase(unittest.TestCase):
+    def _test_mono_column_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
+        self.assertIsNotNone(nlp)
+
+        mono_result = nlp(valid_inputs[0])
+        self.assertIsInstance(mono_result, list)
+        self.assertIsInstance(mono_result[0], (dict, list))
+
+        if isinstance(mono_result[0], list):
+            mono_result = mono_result[0]
+
+        for key in output_keys:
+            self.assertIn(key, mono_result[0])
+
+        multi_result = nlp(valid_inputs)
+        self.assertIsInstance(multi_result, list)
+        self.assertIsInstance(multi_result[0], (dict, list))
+
+        if isinstance(multi_result[0], list):
+            multi_result = multi_result[0]
+
+        for result in multi_result:
+            for key in output_keys:
+                self.assertIn(key, result)
+
+        self.assertRaises(Exception, nlp, invalid_inputs)
+
+    def test_ner(self):
+        mandatory_keys = {'entity', 'word', 'score'}
+        valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
+        invalid_inputs = [None]
+        for tokenizer, model, config in NER_FINETUNED_MODELS:
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = tf_pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = torch_pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+    def test_sentiment_analysis(self):
+        mandatory_keys = {'label'}
+        valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
+        invalid_inputs = [None]
+        for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = tf_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = torch_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+    def test_features_extraction(self):
+        valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
+        invalid_inputs = [None]
+        for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = tf_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
+
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = torch_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
+
+
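Note: for reference, the shapes _test_mono_column_pipeline asserts, with illustrative values only (actual labels and scores depend on the checkpoints above). Sentiment analysis yields one dict per input, while NER yields a list of entity dicts per input, which is why the helper unwraps one level when mono_result[0] is itself a list:

    sentiment_output = [{'label': 'POSITIVE'}]  # checked against mandatory_keys = {'label'}
    ner_output = [[{'entity': 'I-ORG', 'word': 'HuggingFace', 'score': 0.99}]]  # per-input nesting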
+class MultiColumnInputTestCase(unittest.TestCase):
+    def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
+        self.assertIsNotNone(nlp)
+
+        mono_result = nlp(valid_inputs[0])
+        self.assertIsInstance(mono_result, dict)
+
+        for key in output_keys:
+            self.assertIn(key, mono_result)
+
+        multi_result = nlp(valid_inputs)
+        self.assertIsInstance(multi_result, list)
+        self.assertIsInstance(multi_result[0], dict)
+
+        for result in multi_result:
+            for key in output_keys:
+                self.assertIn(key, result)
+
+        self.assertRaises(Exception, nlp, invalid_inputs[0])
+        self.assertRaises(Exception, nlp, invalid_inputs)
+
+    def test_question_answering(self):
+        mandatory_output_keys = {'score', 'answer', 'start', 'end'}
+        valid_samples = [
+            {'question': 'Where was HuggingFace founded ?', 'context': 'HuggingFace was founded in Paris.'},
+            {
+                'question': 'In what field is HuggingFace working ?',
+                'context': 'HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.'
+            }
+        ]
+        invalid_samples = [
+            {'question': '', 'context': 'This is a test to try empty question edge case'},
+            {'question': None, 'context': 'This is a test to try empty question edge case'},
+            {'question': 'What is does with empty context ?', 'context': ''},
+            {'question': 'What is does with empty context ?', 'context': None},
+        ]
+
+        for tokenizer, model, config in QA_FINETUNED_MODELS:
+
+            # Test for Tensorflow
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
+                self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
+
+            # Test for PyTorch
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
+                self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
 
 
 if __name__ == '__main__':
     unittest.main()
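Note: taken together, the test rewrite drops the per-framework QA classes in favour of shared mono-/multi-column test cases driven by the registries above, forcing each backend by patching the framework availability checks. A condensed, hedged sketch of that pattern (illustrative, not the exact file):

    import unittest
    from unittest.mock import patch

    from transformers import pipeline

    # One registry entry, taken from the diff above.
    QA_FINETUNED_MODELS = {('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None)}

    class FrameworkPatternSketch(unittest.TestCase):
        def test_both_backends(self):
            for tokenizer, model, config in QA_FINETUNED_MODELS:
                # Hide PyTorch so pipeline() takes the TensorFlow branch...
                with patch('transformers.pipelines.is_torch_available', return_value=False):
                    nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
                    self.assertIsNotNone(nlp)
                # ...then hide TensorFlow to exercise the PyTorch branch.
                with patch('transformers.pipelines.is_tf_available', return_value=False):
                    nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
                    self.assertIsNotNone(nlp)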