Moving `summarization` pipeline to new testing format. (#13279)

* Moving `summarization` pipeline to new testing format.
* Remove `generate_kwargs` from `__init__` args.

parent 55fb88d369
commit 879fe8fa75
```diff
@@ -110,6 +110,7 @@ class Text2TextGenerationPipeline(Pipeline):
             - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
               -- The token ids of the generated text.
         """
         assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
 
         with self.device_placement():
```
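For context, the assertion in this hunk enforces that a caller requests at least one of the two documented output modes. A minimal usage sketch of both modes (the `t5-small` checkpoint is an illustrative choice, not part of this commit):

```python
from transformers import pipeline

# Illustrative checkpoint; any text2text checkpoint behaves the same way.
generator = pipeline("text2text-generation", model="t5-small")

# return_text=True (the default) decodes the generated ids into strings.
print(generator("translate English to French: hello", return_text=True))
# [{'generated_text': '...'}]

# return_tensors=True returns the raw generated token ids instead.
print(generator("translate English to French: hello", return_tensors=True, return_text=False))
# [{'generated_token_ids': tensor([...])}]
```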
```diff
@@ -267,7 +268,7 @@ class TranslationPipeline(Text2TextGenerationPipeline):
     def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation):
         if getattr(self.tokenizer, "_build_translation_inputs", None):
             return self.tokenizer._build_translation_inputs(
-                *args, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
+                *args, return_tensors=self.framework, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
             )
         else:
             return super()._parse_and_tokenize(*args, truncation=truncation)
```
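The one functional change here threads the pipeline's framework ("pt" or "tf") through to the tokenizer, so the encoded inputs come back as tensors of the framework the model expects. A minimal sketch of the same call in isolation, assuming an MBart50-style multilingual tokenizer that implements `_build_translation_inputs` (checkpoint and language codes are illustrative):

```python
from transformers import AutoTokenizer

# MBart50-style tokenizers expose _build_translation_inputs, which encodes the
# text with the source-language token and records the forced BOS token id for
# the target language.
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

inputs = tokenizer._build_translation_inputs(
    "UN Chief says there is no military solution in Syria",
    return_tensors="pt",  # mirrors the pipeline's self.framework ("pt" or "tf")
    src_lang="en_XX",
    tgt_lang="fr_XX",
)
print(inputs["input_ids"].shape)      # a torch tensor, because return_tensors="pt"
print(inputs["forced_bos_token_id"])  # token id of fr_XX
```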
```diff
@@ -14,84 +14,74 @@
 
 import unittest
 
-from transformers import AutoTokenizer, is_torch_available, pipeline
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers import (
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    LEDConfig,
+    SummarizationPipeline,
+    T5Config,
+    pipeline,
+)
+from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device
 from transformers.tokenization_utils import TruncationStrategy
 
-from .test_pipelines_common import MonoInputPipelineCommonMixin
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
 
-if is_torch_available():
-    import torch
-    from torch import nn
-
-    from transformers.models.bart import BartConfig, BartForConditionalGeneration
-
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
 
 
-class SimpleSummarizationPipelineTests(unittest.TestCase):
-    @require_torch
-    def test_input_too_long(self):
-        torch.manual_seed(0)
-        config = BartConfig(
-            vocab_size=257,
-            d_model=32,
-            encoder_layers=1,
-            decoder_layers=1,
-            encoder_ffn_dim=32,
-            decoder_ffn_dim=32,
-            # So any text > 4 should raise an exception
-            max_position_embeddings=4,
-            encoder_attention_heads=1,
-            decoder_attention_heads=1,
-            max_length=4,
-            min_length=1,
-            forced_eos_token_id=None,
-        )
-        model = BartForConditionalGeneration(config)
-        # Bias output towards L
-        V, C = model.lm_head.weight.shape
-
-        bias = torch.zeros(V)
-        bias[76] = 10
-
-        model.lm_head.bias = nn.Parameter(bias)
-
-        # # Generated with:
-        # import tempfile
-        # from tokenizers import Tokenizer, models
-        # from transformers import PreTrainedTokenizerFast
-        # model_max_length = 4
-        # vocab = [(chr(i), i) for i in range(256)]
-        # tokenizer = Tokenizer(models.Unigram(vocab))
-        # with tempfile.NamedTemporaryFile() as f:
-        #     tokenizer.save(f.name)
-        # real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
-        # real_tokenizer._tokenizer.save("tokenizer.json")
-        # # + add missing config.json with albert as model_type
-        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
-        summarizer = pipeline(task="summarization", model=model, tokenizer=tokenizer)
-
-        with self.assertLogs("transformers", level="WARNING"):
-            with self.assertRaises(IndexError):
-                _ = summarizer("This is a test")
-
-        output = summarizer("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
-        # 2 is default BOS from Bart.
-        self.assertEqual(output, [{"summary_text": "\x02 L L L"}])
-
-
-class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
-    pipeline_task = "summarization"
-    pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}
-    small_models = [
-        "patrickvonplaten/t5-tiny-random",
-        "sshleifer/bart-tiny-random",
-    ]  # Models tested without the @slow decorator
-    large_models = []  # Models tested with the @slow decorator
-    invalid_inputs = [4, "<mask>"]
-    mandatory_keys = ["summary_text"]
+@is_pipeline_test
+class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+    tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
+        summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = summarizer("(CNN)The Palestinian Authority officially became")
+        self.assertEqual(outputs, [{"summary_text": ANY(str)}])
+
+        outputs = summarizer(
+            "(CNN)The Palestinian Authority officially became ",
+            num_beams=2,
+            min_length=2,
+            max_length=5,
+        )
+        self.assertEqual(outputs, [{"summary_text": ANY(str)}])
+
+        if not isinstance(model.config, (T5Config, LEDConfig)):
+            # LED, T5 can handle it.
+            # Too long.
+            with self.assertRaises(Exception):
+                outputs = summarizer("This " * 1000)
+            outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)
+
+    @require_torch
+    def test_small_model_pt(self):
+        summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", framework="pt")
+        outputs = summarizer("This is a small test")
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "summary_text": "เข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไป"
+                }
+            ],
+        )
+
+    @require_tf
+    def test_small_model_tf(self):
+        summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", framework="tf")
+        outputs = summarizer("This is a small test")
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "summary_text": "เข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไป"
+                }
+            ],
+        )
 
     @require_torch
     @slow
```
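The new format replaces the list-driven `MonoInputPipelineCommonMixin` with a metaclass: `PipelineTestCaseMeta` fans `run_pipeline_test` out over every architecture registered in `model_mapping` / `tf_model_mapping`, so newly added seq2seq models are covered without editing this file. A minimal sketch of the mechanism (simplified names and signature, not the actual `PipelineTestCaseMeta` implementation):

```python
import unittest


class PipelineTestCaseMetaSketch(type):
    """Illustration only: generate one test method per architecture in model_mapping."""

    def __new__(mcs, name, bases, namespace):
        for arch in namespace.get("model_mapping") or {}:
            def make_test(arch=arch):  # bind `arch` per iteration
                def test(self):
                    # The real metaclass would build a tiny model and tokenizer
                    # for this architecture and pass them to run_pipeline_test.
                    self.run_pipeline_test(arch)

                return test

            namespace[f"test_pipeline_{arch}"] = make_test()
        return super().__new__(mcs, name, bases, namespace)


class DemoSummarizationTests(unittest.TestCase, metaclass=PipelineTestCaseMetaSketch):
    # Stand-in for MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
    model_mapping = {"bart": None, "t5": None}

    def run_pipeline_test(self, arch):
        self.assertIsInstance(arch, str)
```

Running `python -m unittest` against this sketch executes `test_pipeline_bart` and `test_pipeline_t5` as two separate cases, which is the behaviour the real suite relies on to report failures per architecture.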