Generate: input expansion for any model input (#21624)

Author: Joao Gante (committed by GitHub, 2023-02-14 14:16:22 +00:00)
Commit: a81fe4e1df (parent: 13e03e619d)
3 changed files with 58 additions and 67 deletions
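
In short: before this commit, `_expand_inputs_for_generation` only expanded a hard-coded set of inputs (`token_type_ids`, `attention_mask`, `decoder_attention_mask`, and `encoder_outputs`); any other tensor in `model_kwargs` (e.g. the `pixel_values` used by the BLIP-2 tests below) was left at its original batch size, breaking beam search and other expansion-based decoding for such models. The new code walks `model_kwargs` and expands every tensor it finds. A minimal standalone sketch of the idea (hypothetical function name, not the library API):

import torch

def expand_any_input(expand_size: int, **model_kwargs):
    # Expand *every* tensor kwarg from [batch_size, ...] to [batch_size * expand_size, ...],
    # instead of checking a fixed list of argument names.
    for key, value in model_kwargs.items():
        if isinstance(value, torch.Tensor):
            model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
    return model_kwargs

expanded = expand_any_input(
    expand_size=2,
    attention_mask=torch.ones(3, 5),            # previously handled by a dedicated branch
    pixel_values=torch.randn(3, 3, 224, 224),   # previously ignored by the expansion logic
)
print(expanded["pixel_values"].shape)  # torch.Size([6, 3, 224, 224])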


@@ -986,17 +986,13 @@ class TFGenerationMixin:
             )
 
         # 11. broadcast inputs to the desired number of beams
-        input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-        if "encoder_outputs" in model_kwargs:
-            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-            )
-
-        if "attention_mask" in model_kwargs:
-            model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-            )
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=generation_config.num_beams,
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            expand_in_new_axis=True,
+            **model_kwargs,
+        )
 
         # 12. run beam search
         return self.beam_search(
@@ -1025,17 +1021,13 @@ class TFGenerationMixin:
         logits_warper = self._get_logits_warper(generation_config=generation_config)
 
         # 12. broadcast inputs to the desired number of beams
-        input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-        if "encoder_outputs" in model_kwargs:
-            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-            )
-
-        if "attention_mask" in model_kwargs:
-            model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-            )
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=generation_config.num_beams,
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            expand_in_new_axis=True,
+            **model_kwargs,
+        )
 
         # 13. run beam sample (beam search with sampling)
         return self.beam_search(
@@ -1054,11 +1046,6 @@ class TFGenerationMixin:
             **model_kwargs,
         )
 
-    @staticmethod
-    def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        shape = shape_list(tensor)
-        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
-
     def _prepare_attention_mask_for_generation(
         self,
         inputs: tf.Tensor,
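
The deleted `_expand_to_num_beams` helper is not lost: its broadcast-to-a-new-axis behaviour moves into `_expand_inputs_for_generation` below, behind the `expand_in_new_axis=True` flag. A quick shape sketch of what that broadcast does (standalone example, illustrative values):

import tensorflow as tf

tensor = tf.ones((4, 7))  # [batch_size, seq_len]
num_beams = 3

# Insert a beam axis, then broadcast along it: [batch_size, num_beams, seq_len]
expanded = tf.broadcast_to(tensor[:, None], (4, num_beams, 7))
print(expanded.shape)  # (4, 3, 7)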
@@ -1142,29 +1129,37 @@ class TFGenerationMixin:
         expand_size: int = 1,
         is_encoder_decoder: bool = False,
         input_ids: Optional[tf.Tensor] = None,
+        expand_in_new_axis: bool = False,
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
-        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+        """
+        Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
+        depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
+        `expand_in_new_axis=True`
+        """
+
+        def _expand_tensor(tensor: tf.Tensor):
+            if expand_in_new_axis:
+                shape = shape_list(tensor)
+                return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))
+            else:
+                return tf.repeat(tensor, expand_size, axis=0)
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
+                    dict_to_expand[key] = _expand_tensor(dict_to_expand[key])
+            return dict_to_expand
 
         if input_ids is not None:
-            input_ids = tf.repeat(input_ids, expand_size, axis=0)
+            input_ids = _expand_tensor(input_ids)
 
-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = tf.repeat(model_kwargs["token_type_ids"], expand_size, axis=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = tf.repeat(model_kwargs["attention_mask"], expand_size, axis=0)
-
-        if model_kwargs.get("decoder_attention_mask") is not None:
-            model_kwargs["decoder_attention_mask"] = tf.repeat(
-                model_kwargs["decoder_attention_mask"], expand_size, axis=0
-            )
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
 
         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, expand_size, axis=0)
-            model_kwargs["encoder_outputs"] = encoder_outputs
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
         return input_ids, model_kwargs
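
The two modes of `_expand_tensor` produce different layouts: `tf.repeat` interleaves copies along the batch axis (what sampling-style decoding uses), while the broadcast adds a dedicated beam axis (what the TF beam search implementation expects). A small standalone sketch of both, with illustrative shapes:

import tensorflow as tf

t = tf.reshape(tf.range(4), (2, 2))  # [batch_size=2, seq_len=2]

# expand_in_new_axis=False: rows repeated in place -> [batch_size * expand_size, ...]
flat = tf.repeat(t, 2, axis=0)
print(flat.shape)   # (4, 2)

# expand_in_new_axis=True: a new beam axis -> [batch_size, expand_size, ...]
beams = tf.broadcast_to(t[:, None], (2, 2, 2))
print(beams.shape)  # (2, 2, 2)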


@@ -671,26 +671,22 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
 
-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].repeat_interleave(expand_size, dim=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
 
         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
-                expand_size, dim=0
-            )
-            model_kwargs["encoder_outputs"] = encoder_outputs
-
-            decoder_attention_mask = model_kwargs.get("decoder_attention_mask")
-            if decoder_attention_mask is not None:
-                model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0)
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
         return input_ids, model_kwargs
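
Note the `isinstance(..., torch.Tensor)` guard: non-tensor entries in `model_kwargs` (flags such as `use_cache`, or `None` values) pass through untouched, so the generic loop is safe to run over the whole dict. A standalone sketch of that behaviour (a free-standing copy of the helper, with `expand_size` passed explicitly for illustration):

import torch

def _expand_dict_for_generation(dict_to_expand, expand_size):
    for key in dict_to_expand:
        if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
            dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
    return dict_to_expand

kwargs = _expand_dict_for_generation(
    {"attention_mask": torch.ones(2, 4), "use_cache": True, "labels": None},
    expand_size=3,
)
print(kwargs["attention_mask"].shape)   # torch.Size([6, 4])
print(kwargs["use_cache"], kwargs["labels"])  # True None -- left as-is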


@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "it's not a city, it's a beach")
 
-    def test_inference_opt_batched(self):
+    def test_inference_opt_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
 
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
 
-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)
 
-        # Test output
-        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
-        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
+        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
 
     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
 
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "san diego")
 
-    def test_inference_t5_batched(self):
+    def test_inference_t5_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
 
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
 
-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)
 
-        # Test output
-        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
-        self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
+        self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
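
These tests now exercise batched beam search on image inputs, which only works once `pixel_values` is expanded like any other model input. A usage sketch of the same call pattern outside the test harness (the checkpoint id comes from the test above; the image path is a placeholder):

from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.open("example.jpg")  # placeholder image path
inputs = processor(images=[image, image], return_tensors="pt")

# num_beams > 1 on a batched image input -- the case this commit enables
predictions = model.generate(**inputs, num_beams=2)
print(processor.batch_decode(predictions, skip_special_tokens=True))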