Generate: TF can now generate from embeddings in encoder-decoder models (#21475)

Joao Gante 2023-02-07 11:18:23 +00:00 committed by GitHub
parent 12eb528b5a
commit 1e4cf8bb44
4 changed files with 184 additions and 197 deletions

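A minimal usage sketch of the capability this commit enables, mirroring the new `test_encoder_decoder_generate_with_inputs_embeds` test added below (the tiny BART checkpoint, `max_length=5`, and the embedding lookup come from that test; the rest is illustrative and not part of the diff):

from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
model.config.eos_token_id = None

article = "Justin Timberlake and Jessica Biel, welcome to parenthood."
input_ids = tokenizer(article, return_tensors="tf").input_ids
# Turn the prompt into embeddings; after this commit, TF encoder-decoder models can
# consume these directly in `generate()` instead of requiring `input_ids`.
inputs_embeds = model.get_input_embeddings()(input_ids)

output_sequences = model.generate(inputs_embeds=inputs_embeds)
print(output_sequences.shape)  # expected (1, 5): decoder start token + generation up to max_length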
View File

@@ -664,9 +664,11 @@ class TFGenerationMixin:
)
# 4. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, generation_config.bos_token_id)
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
input_ids, generation_config.bos_token_id, model_kwargs
)
# inputs_ids now has to be defined and cannot be None anymore
batch_size = shape_list(input_ids)[0]
batch_size = shape_list(inputs_tensor)[0]
# 5. Prepare other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
@@ -678,23 +680,26 @@ class TFGenerationMixin:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if generation_config.pad_token_id is not None and tf.math.reduce_any(
input_ids[:, -1] == generation_config.pad_token_id
inputs_tensor[:, -1] == generation_config.pad_token_id
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
# 6. Prepare model inputs which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
@@ -702,6 +707,9 @@ class TFGenerationMixin:
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
@@ -924,7 +932,9 @@ class TFGenerationMixin:
else:
return tf.ones(inputs.shape[:2], dtype=tf.int32)
def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: tf.Tensor, model_kwargs) -> Dict[str, Any]:
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
# get encoder and store encoder outputs
encoder = self.get_encoder()
@@ -938,7 +948,9 @@ class TFGenerationMixin:
# vision models don't use `attention_mask`.
encoder_kwargs["return_dict"] = True
encoder_kwargs[self.main_input_name] = inputs_tensor
encoder_kwargs[model_input_name] = inputs_tensor
if model_input_name != self.main_input_name: # in Keras, the first input must always be passed
encoder_kwargs[self.main_input_name] = None
encoder_outputs = encoder(**encoder_kwargs)
model_kwargs["encoder_outputs"] = encoder_outputs
@@ -1007,19 +1019,79 @@ class TFGenerationMixin:
return input_ids, model_kwargs
def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
# TODO(Patrick) - adapt this function when making `generate` more flexible
# for all kinds of input types
if inputs is None:
# if no `inputs` are passed create prompt of size (1,1) filled with BOS token
if not isinstance(bos_token_id, int) or bos_token_id < 0:
raise ValueError(
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
return tf.cast(tf.fill((1, 1), bos_token_id), dtype=tf.int32)
def _prepare_model_inputs(
self,
inputs: Optional[tf.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
) -> Tuple[tf.Tensor, Optional[str], Dict[str, tf.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and hasattr(self.encoder, "main_input_name")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
return inputs
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
if inputs is None:
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
return inputs, input_name, model_kwargs
def _prepare_input_ids_for_generation(
self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
) -> tf.Tensor:
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = shape_list(encoder_outputs.last_hidden_state)[:-1]
return tf.ones(shape, dtype=tf.int32) * -100
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return tf.ones((1, 1), dtype=tf.int32) * bos_token_id
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput):

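As a side note on the Keras constraint referenced above ("the first input must always be passed"): a small illustrative snippet, not part of the diff, showing how the encoder kwargs end up looking when generation starts from embeddings (the tensor shape and hidden size are placeholders):

import tensorflow as tf

# Placeholder embeddings standing in for `inputs_tensor` after `_prepare_model_inputs`
# resolved the user-supplied `inputs_embeds` (hypothetical batch/seq/hidden sizes).
inputs_tensor = tf.random.uniform((1, 29, 16))
model_input_name = "inputs_embeds"   # what the user passed to `generate()`
main_input_name = "input_ids"        # the model's declared first input

encoder_kwargs = {"return_dict": True, model_input_name: inputs_tensor}
if model_input_name != main_input_name:
    # Keras still expects the first declared input, so it is passed explicitly as None
    encoder_kwargs[main_input_name] = None

# encoder_outputs = encoder(**encoder_kwargs)  # the encoder runs on the embeddings only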
View File

@@ -5,11 +5,13 @@ Framework agnostic tests for generate()-related methods.
import numpy as np
from transformers import AutoTokenizer
from transformers.testing_utils import torch_device
class GenerationIntegrationTestsMixin:
# To be populated by the child classes
framework_dependent_parameters = {
"AutoModelForCausalLM": None,
"AutoModelForSeq2SeqLM": None,
"LogitsProcessorList": None,
"MinLengthLogitsProcessor": None,
@@ -60,3 +62,91 @@ class GenerationIntegrationTestsMixin:
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
def test_max_new_tokens_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart")
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
if is_pt:
bart_model = bart_model.to(torch_device)
input_ids = input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_max_new_tokens_decoder_only(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
article = """Justin Timberlake."""
gpt2_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = gpt2_tokenizer(article, return_tensors=return_tensors).input_ids
if is_pt:
gpt2_model = gpt2_model.to(torch_device)
input_ids = input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gpt2_model.config.max_length = 20
# call < 20
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_encoder_decoder_generate_with_inputs_embeds(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
model.config.eos_token_id = None
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)
output_sequences = model.generate(inputs_embeds=inputs_embeds)
# make sure model generated correctly until `max_length`
self.assertEqual(output_sequences.shape, (1, 5))

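For orientation before the framework-specific files: a condensed, hypothetical sketch of how a child test class plugs into this mixin (the real TF and PyTorch classes below gate the dictionary behind `is_tf_available()` / `is_torch_available()` and register more entries):

import unittest

from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM
from transformers.testing_utils import require_tf

from .test_framework_agnostic import GenerationIntegrationTestsMixin  # hypothetical relative import


@require_tf
class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
    # keys consumed by the mixin's tests, mapped to their TF implementations
    framework_dependent_parameters = {
        "AutoModelForCausalLM": TFAutoModelForCausalLM,
        "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
        "return_tensors": "tf",
    }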
View File

@@ -135,6 +135,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
if is_tf_available():
framework_dependent_parameters = {
"AutoModelForCausalLM": TFAutoModelForCausalLM,
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
"LogitsProcessorList": TFLogitsProcessorList,
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,

View File

@@ -40,7 +40,6 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
VisionEncoderDecoderModel,
top_k_top_p_filtering,
)
@@ -1792,6 +1791,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
if is_torch_available():
framework_dependent_parameters = {
"AutoModelForCausalLM": AutoModelForCausalLM,
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
"LogitsProcessorList": LogitsProcessorList,
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
@@ -2094,182 +2094,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
output = generator(prompt, stop_sequence=" number")
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
def test_max_new_tokens_encoder_decoder(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 56])
max_new_tokens = 3
t5_model.config.max_length = 20
t5_model.config.eos_token_id = None
# Encoder decoder call
outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 56 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 59])
# Encoder decoder call > 20
outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
article = """Justin Timberlake."""
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gptj_model.config.max_length = 20
# call < 20
outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gpt2_model.config.max_length = 20
# call < 20
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_max_new_tokens_decoder_only(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gpt2_model.config.max_length = 20
# call < 20
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
def test_encoder_decoder_generate_with_inputs_embeds(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
torch_device
)
model.config.eos_token_id = None
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)
output_sequences = model.generate(inputs_embeds=inputs_embeds)
# make sure model generated correctly until `max_length`
self.assertEqual(output_sequences.shape, (1, 5))
def test_encoder_decoder_generate_attention_mask(self):
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")