Generate: correct default model input creation for decoder-only models (#21580)

Joao Gante 2023-02-13 17:04:49 +00:00 committed by GitHub
parent edc1e734bf
commit fa4bdb0a40
4 changed files with 109 additions and 17 deletions
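
For context, this commit lets decoder-only models be driven from `inputs_embeds` alone, with default `input_ids` created at the correct batch size. A minimal sketch of the new behaviour, closely following the test added below (the checkpoint and shapes come from that test; `get_input_embeddings()` is the generic accessor for the lookup the test does via `model.transformer.wte`):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny testing checkpoint, as used in the new test below.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(["Hello world", "Hello world"], return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# With this fix, a decoder-only model accepts `inputs_embeds` without `input_ids`:
# default `input_ids` are created as one BOS token per batch row, so the output has
# a single leading column followed by the newly generated tokens.
outputs = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)
print(outputs.shape)  # expected: (2, 11) -> 1 BOS column + 10 new tokens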


@@ -845,8 +845,7 @@ class TFGenerationMixin:
model_kwargs=model_kwargs,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = shape_list(input_ids)[-1]
@@ -1214,20 +1213,34 @@ class TFGenerationMixin:
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
if inputs is None:
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
inputs = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
)
return inputs, input_name, model_kwargs
def _prepare_input_ids_for_generation(
self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[tf.Tensor] = None,
bos_token_id: Optional[int] = None,
encoder_outputs: Optional[ModelOutput] = None,
batch_size: Optional[int] = None,
) -> tf.Tensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.shape[:-1]
@@ -1235,7 +1248,9 @@ class TFGenerationMixin:
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return tf.ones((1, 1), dtype=tf.int32) * bos_token_id
batch_size = batch_size if batch_size is not None else 1
return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput):
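
To make the TF change concrete: the helper now sizes the default `input_ids` to the requested batch size instead of always returning a `(1, 1)` tensor. A standalone sketch with illustrative values only (the BOS id is GPT-2's and is just an example):

import tensorflow as tf

batch_size, bos_token_id = 2, 50256  # illustrative values, not taken from the diff

# Old behaviour: a single BOS row, regardless of how many sequences are generated.
old_default = tf.ones((1, 1), dtype=tf.int32) * bos_token_id

# New behaviour: one BOS row per batch entry, e.g. matching `inputs_embeds.shape[0]`.
new_default = tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
print(old_default.shape, new_default.shape)  # (1, 1) (2, 1)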


@@ -541,15 +541,20 @@ class GenerationMixin:
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
if inputs is None:
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
inputs = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
)
return inputs, input_name, model_kwargs
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@@ -558,9 +563,17 @@ class GenerationMixin:
"""
return logits
def _prepare_input_ids_for_generation(
self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
encoder_outputs: Optional[ModelOutput] = None,
batch_size: Optional[int] = None,
) -> torch.LongTensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
@@ -568,7 +581,9 @@ class GenerationMixin:
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
batch_size = batch_size if batch_size is not None else 1
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(
self,
@@ -1258,8 +1273,7 @@ class GenerationMixin:
device=inputs_tensor.device,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
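
Putting the two PyTorch hunks together: when `inputs_embeds` is the main model input, `_prepare_model_inputs` stashes default `input_ids` in `model_kwargs`, and the decoder-only branch later pops them instead of treating the embeddings tensor as `input_ids`. A simplified, hypothetical walkthrough of that flow (not the library code itself):

import torch

inputs_embeds = torch.randn(2, 5, 32)   # (batch, seq, hidden); made-up sizes
bos_token_id = 50256                     # illustrative BOS id
model_kwargs = {"inputs_embeds": inputs_embeds}

# Step 1 (see `_prepare_model_inputs`): default `input_ids` sized to the embeddings batch.
model_kwargs["input_ids"] = (
    torch.ones((inputs_embeds.shape[0], 1), dtype=torch.long) * bos_token_id
)

# Step 2 (see the decoder-only branch): `model_input_name == "inputs_embeds"`, so the
# pre-built ids are popped and used to track the generated sequence.
input_ids = model_kwargs.pop("input_ids")
print(input_ids.shape)  # torch.Size([2, 1])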


@@ -2488,3 +2488,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
eos_token_id = [846, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_generate_from_inputs_embeds_decoder_only(self):
# Note: the model must support generation from input embeddings
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
text = "Hello world"
tokenized_inputs = tokenizer([text, text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_from_ids = model.generate(input_ids)
self.assertEqual(outputs_from_ids.shape, (2, 20))
# Same thing, but from input embeddings
inputs_embeds = model.transformer.wte(input_ids)
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
# But if we pass different inputs_embeds, we should get different outputs
torch.manual_seed(0)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
)
self.assertListEqual(
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
outputs_from_embeds_wo_ids[:, 1:].tolist(),
)
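
A note on the last comparison in this test: generation from `inputs_embeds` alone starts from a single default BOS column, so its output is one column wider at the front than the generated suffix of the ids-based output, which is why the test trims `max_new_tokens` and slices with `[:, 1:]`. A hypothetical shape check, assuming a 2-token prompt and the default `max_length` of 20:

prompt_len, max_length = 2, 20  # assumed values, for illustration only
new_tokens_from_ids = max_length - prompt_len           # outputs_from_embeds[:, prompt_len:]
cols_from_embeds_only = 1 + (max_length - prompt_len)   # BOS column + new tokens
assert cols_from_embeds_only - 1 == new_tokens_from_ids  # hence the `[:, 1:]` slice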


@@ -797,6 +797,20 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(generated_text, "it's not a city, it's a beach")
def test_inference_opt_batched(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
# prepare image
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs)
# Test output
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
def test_inference_t5(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained(
@@ -827,3 +841,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
[0, 3, 7, 152, 67, 839, 1],
)
self.assertEqual(generated_text, "san diego")
def test_inference_t5_batched(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
# prepare image
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs)
# Test output
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
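
For completeness, the batched predictions in these two tests can be decoded back to text with the processor. A usage sketch only, reusing `image` from `prepare_img()` as the tests do; the decoded caption strings are not asserted here, only that the two rows match:

from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")

inputs = processor(images=[image, image], return_tensors="pt")  # `image` from prepare_img()
predictions = model.generate(**inputs)

# The same image is passed twice, so both rows should decode to the same caption.
captions = processor.batch_decode(predictions, skip_special_tokens=True)
assert captions[0].strip() == captions[1].strip()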