All tests are green.

Morgan Funtowicz 2019-12-20 11:47:56 +01:00
parent e516a34a15
commit 61d9ee45e3
2 changed files with 154 additions and 102 deletions

View File

@@ -343,8 +343,9 @@ class Pipeline(_ScikitCompat):
         if 'distilbert' not in model_type and 'xlm' not in model_type:
             args += ['token_type_ids']
 
-        if 'xlnet' in model_type or 'xlm' in model_type:
-            args += ['cls_index', 'p_mask']
+        # PR #1548 (CLI) There is an issue with attention_mask
+        # if 'xlnet' in model_type or 'xlm' in model_type:
+        #     args += ['cls_index', 'p_mask']
 
         if isinstance(features, dict):
             return {k: features[k] for k in args}
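For context, a minimal standalone sketch (with a hypothetical features dict, not the real tokenizer output) of what the argument filtering now produces for an XLNet-style model:

    # Hypothetical encoded features; after this change cls_index and p_mask
    # are no longer requested for xlnet/xlm models.
    features = {'input_ids': [[1, 2]], 'attention_mask': [[1, 1]],
                'token_type_ids': [[0, 0]], 'cls_index': [0], 'p_mask': [[0, 0]]}
    model_type = 'tfxlnetmodel'

    args = ['input_ids', 'attention_mask']
    if 'distilbert' not in model_type and 'xlm' not in model_type:
        args += ['token_type_ids']

    inputs = {k: features[k] for k in args}  # cls_index / p_mask dropped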
@@ -380,7 +381,7 @@ class Pipeline(_ScikitCompat):
         if is_tf_available():
             predictions = self.model(inputs, training=False)[0]
         else:
             with torch.no_grad():
-                predictions = self.model(**inputs).cpu()[0]
+                predictions = self.model(**inputs)[0].cpu()
 
         return predictions.numpy()
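The PyTorch fix above is purely about ordering: the model returns a tuple of tensors, and a tuple has no .cpu() method, so the tensor must be pulled out of the tuple before being moved to the CPU. A minimal sketch, with a dummy tuple standing in for the model output:

    import torch

    outputs = (torch.ones(1, 2),)     # stand-in for self.model(**inputs)
    # outputs.cpu()                   # AttributeError: 'tuple' object has no attribute 'cpu'
    predictions = outputs[0].cpu()    # index into the tuple first, then move the tensor
    print(predictions.numpy())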
@@ -444,7 +445,7 @@ class NerPipeline(Pipeline):
 
             # Forward
             if is_tf_available():
-                entities = self.model(**tokens)[0][0].numpy()
+                entities = self.model(tokens)[0][0].numpy()
             else:
                 with torch.no_grad():
                     entities = self.model(**tokens)[0][0].cpu().numpy()
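The TF branch changes the calling convention: a tf.keras model takes its inputs as a single positional argument (here the dict of encoded tensors), while the PyTorch branch keeps unpacking the dict into keyword arguments. A toy Keras model illustrating the difference (hypothetical shapes, not the actual NER model):

    import tensorflow as tf

    inp = tf.keras.Input(shape=(4,), name='input_ids')
    model = tf.keras.Model(inputs={'input_ids': inp},
                           outputs=tf.keras.layers.Dense(2)(inp))

    tokens = {'input_ids': tf.ones((1, 4))}
    logits = model(tokens).numpy()  # the dict is passed as one argument
    # model(**tokens) raises TypeError: Keras reserves keyword arguments
    # for options such as training=..., not for input tensors.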

View File

@@ -1,113 +1,164 @@
 import unittest
 from unittest.mock import patch
+from typing import Iterable
+
+from transformers import pipeline
+from transformers.tests.utils import require_tf, require_torch
 QA_FINETUNED_MODELS = {
-    'bert-large-uncased-whole-word-masking-finetuned-squad',
-    'bert-large-cased-whole-word-masking-finetuned-squad',
-    'distilbert-base-uncased-distilled-squad',
+    ('bert-base-uncased', 'bert-large-uncased-whole-word-masking-finetuned-squad', None),
+    ('bert-base-cased', 'bert-large-cased-whole-word-masking-finetuned-squad', None),
+    ('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None)
 }
+
+NER_FINETUNED_MODELS = {
+    (
+        'bert-base-cased',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json'
+    )
+}
+
+FEATURE_EXTRACT_FINETUNED_MODELS = {
+    ('bert-base-cased', 'bert-base-cased', None),
+    # ('xlnet-base-cased', 'xlnet-base-cased', None),  # Disabled for now as it crashes on TF2
+    ('distilbert-base-uncased', 'distilbert-base-uncased', None)
+}
+
+TEXT_CLASSIF_FINETUNED_MODELS = {
+    (
+        'bert-base-uncased',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json'
+    )
+}
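Each fixture above is a (tokenizer, model, config) triple rather than a bare model name; model and config can be shortcut names or S3 URLs, and config may be None to fall back to the configuration shipped with the model. A hypothetical unpacking, mirroring the tests below:

    for tokenizer_name, model_name, config_name in QA_FINETUNED_MODELS:
        nlp = pipeline('question-answering', model=model_name,
                       config=config_name, tokenizer=tokenizer_name)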
-
-
-class QuestionAnsweringPipelineTest(unittest.TestCase):
-    def check_answer_structure(self, answer, batch, topk):
-        self.assertIsInstance(answer, list)
-        self.assertEqual(len(answer), batch)
-        self.assertIsInstance(answer[0], list)
-        self.assertEqual(len(answer[0]), topk)
-        self.assertIsInstance(answer[0][0], dict)
-
-        for item in answer[0]:
-            self.assertTrue('start' in item)
-            self.assertTrue('end' in item)
-            self.assertTrue('score' in item)
-            self.assertTrue('answer' in item)
-
-    def question_answering_pipeline(self, nlp):
-        # Simple case with topk = 1, no batching
-        a = nlp(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.')
-        self.check_answer_structure(a, 1, 1)
-
-        # Simple case with topk = 2, no batching
-        a = nlp(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.', topk=2)
-        self.check_answer_structure(a, 1, 2)
-
-        # Batch case with topk = 1
-        a = nlp(question=['What is the name of the company I\'m working for ?', 'Where is the company based ?'],
-                context=['I\'m working for Huggingface.', 'The company is based in New York and Paris'])
-        self.check_answer_structure(a, 2, 1)
-
-        # Batch case with topk = 2
-        a = nlp(question=['What is the name of the company I\'m working for ?', 'Where is the company based ?'],
-                context=['Where is the company based ?', 'The company is based in New York and Paris'], topk=2)
-        self.check_answer_structure(a, 2, 2)
-
-        # check for data keyword
-        a = nlp(data=nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'))
-        self.check_answer_structure(a, 1, 1)
-
-        a = nlp(data=nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'), topk=2)
-        self.check_answer_structure(a, 1, 2)
-
-        a = nlp(data=[
-            nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'),
-            nlp.create_sample(question='I\'m working for Huggingface.', context='The company is based in New York and Paris'),
-        ])
-        self.check_answer_structure(a, 2, 1)
-
-        a = nlp(data=[
-            {'question': 'What is the name of the company I\'m working for ?', 'context': 'I\'m working for Huggingface.'},
-            {'question': 'Where is the company based ?', 'context': 'The company is based in New York and Paris'},
-        ])
-        self.check_answer_structure(a, 2, 1)
-
-        # X keywords
-        a = nlp(X=nlp.create_sample(
-            question='Where is the company based ?', context='The company is based in New York and Paris'
-        ))
-        self.check_answer_structure(a, 1, 1)
-
-        a = nlp(X=[
-            {'question': 'What is the name of the company I\'m working for ?', 'context': 'I\'m working for Huggingface.'},
-            {'question': 'Where is the company based ?', 'context': 'The company is based in New York and Paris'},
-        ], topk=2)
-        self.check_answer_structure(a, 2, 2)
-
-    @patch('transformers.pipelines.is_torch_available', return_value=False)
-    def test_tf_models(self, is_torch_available):
-        from transformers import pipeline
-
-        for model in QA_FINETUNED_MODELS:
-            self.question_answering_pipeline(pipeline('question-answering', model))
-
-    @patch('transformers.pipelines.is_tf_available', return_value=False)
-    @patch('transformers.tokenization_utils.is_tf_available', return_value=False)
-    def test_torch_models(self, is_tf_available, _):
-        from transformers import pipeline
-
-        for model in QA_FINETUNED_MODELS:
-            self.question_answering_pipeline(pipeline('question-answering', model))
+
+@require_tf
+def tf_pipeline(*args, **kwargs):
+    return pipeline(**kwargs)
-
-
-class AutoPipelineTest(unittest.TestCase):
-    @patch('transformers.pipelines.is_torch_available', return_value=False)
-    def test_tf_qa(self, is_torch_available):
-        from transformers import pipeline
-        from transformers.pipelines import QuestionAnsweringPipeline
-        from transformers.modeling_tf_utils import TFPreTrainedModel
-
-        for model in QA_FINETUNED_MODELS:
-            nlp = pipeline('question-answering', model)
-            self.assertIsInstance(nlp, QuestionAnsweringPipeline)
-            self.assertIsInstance(nlp.model, TFPreTrainedModel)
+
+@require_torch
+def torch_pipeline(*args, **kwargs):
+    return pipeline(**kwargs)
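These two helpers combine the mechanisms the tests below rely on: require_tf/require_torch are meant to skip a test when the corresponding backend is not installed, while patching the availability check forces pipeline() onto the other backend when both are present. The pattern, sketched with a placeholder model name:

    with patch('transformers.pipelines.is_tf_available', return_value=False):
        # pipeline() now sees only PyTorch; torch_pipeline additionally
        # skips the whole test if torch itself is missing.
        nlp = torch_pipeline(task='ner', model='some-model-name', config=None,
                             tokenizer='bert-base-cased')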
-
-    @patch('transformers.pipelines.is_tf_available', return_value=False)
-    def test_torch_qa(self, is_tf_available):
-        from transformers import pipeline
-        from transformers.pipelines import QuestionAnsweringPipeline
-        from transformers.modeling_utils import PreTrainedModel
-
-        for model in QA_FINETUNED_MODELS:
-            nlp = pipeline('question-answering', model)
-            self.assertIsInstance(nlp, QuestionAnsweringPipeline)
-            self.assertIsInstance(nlp.model, PreTrainedModel)
+
+
+class MonoColumnInputTestCase(unittest.TestCase):
+    def _test_mono_column_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
+        self.assertIsNotNone(nlp)
+
+        mono_result = nlp(valid_inputs[0])
+        self.assertIsInstance(mono_result, list)
+        self.assertIsInstance(mono_result[0], (dict, list))
+
+        if isinstance(mono_result[0], list):
+            mono_result = mono_result[0]
+
+        for key in output_keys:
+            self.assertIn(key, mono_result[0])
+
+        multi_result = nlp(valid_inputs)
+        self.assertIsInstance(multi_result, list)
+        self.assertIsInstance(multi_result[0], (dict, list))
+
+        if isinstance(multi_result[0], list):
+            multi_result = multi_result[0]
+
+        for result in multi_result:
+            for key in output_keys:
+                self.assertIn(key, result)
+
+        self.assertRaises(Exception, nlp, invalid_inputs)
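The dict-or-list check above exists because mono-column pipelines do not share an output shape: NER and sentiment analysis yield a list of dicts per input, while feature extraction yields nested lists of floats. Illustrative (hypothetical) values of the two shapes the helper normalizes:

    ner_like = [{'entity': 'I-ORG', 'word': 'HuggingFace', 'score': 0.99}]
    features_like = [[[0.1, 0.2], [0.3, 0.4]]]  # one input -> per-token vectors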
+
+    def test_ner(self):
+        mandatory_keys = {'entity', 'word', 'score'}
+        valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
+        invalid_inputs = [None]
+
+        for tokenizer, model, config in NER_FINETUNED_MODELS:
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = tf_pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = torch_pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+    def test_sentiment_analysis(self):
+        mandatory_keys = {'label'}
+        valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
+        invalid_inputs = [None]
+
+        for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = tf_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = torch_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
+
+    def test_features_extraction(self):
+        valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
+        invalid_inputs = [None]
+
+        for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = tf_pipeline(task='feature-extraction', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
+
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = torch_pipeline(task='feature-extraction', model=model, config=config, tokenizer=tokenizer)
+                self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
+
+
+class MultiColumnInputTestCase(unittest.TestCase):
+    def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
+        self.assertIsNotNone(nlp)
+
+        mono_result = nlp(valid_inputs[0])
+        self.assertIsInstance(mono_result, dict)
+
+        for key in output_keys:
+            self.assertIn(key, mono_result)
+
+        multi_result = nlp(valid_inputs)
+        self.assertIsInstance(multi_result, list)
+        self.assertIsInstance(multi_result[0], dict)
+
+        for result in multi_result:
+            for key in output_keys:
+                self.assertIn(key, result)
+
+        self.assertRaises(Exception, nlp, invalid_inputs[0])
+        self.assertRaises(Exception, nlp, invalid_inputs)
+
+    def test_question_answering(self):
+        mandatory_output_keys = {'score', 'answer', 'start', 'end'}
+        valid_samples = [
+            {'question': 'Where was HuggingFace founded ?', 'context': 'HuggingFace was founded in Paris.'},
+            {
+                'question': 'In what field is HuggingFace working ?',
+                'context': 'HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.'
+            }
+        ]
+        invalid_samples = [
+            {'question': '', 'context': 'This is a test to try empty question edge case'},
+            {'question': None, 'context': 'This is a test to try empty question edge case'},
+            {'question': 'What does it do with empty context ?', 'context': ''},
+            {'question': 'What does it do with empty context ?', 'context': None},
+        ]
+
+        for tokenizer, model, config in QA_FINETUNED_MODELS:
+            # Test for TensorFlow
+            with patch('transformers.pipelines.is_torch_available', return_value=False):
+                nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
+                self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
+
+            # Test for PyTorch
+            with patch('transformers.pipelines.is_tf_available', return_value=False):
+                nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
+                self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
 
 if __name__ == '__main__':
     unittest.main()