From a81fe4e1df788a65c0bcd83be824d0ac93e7fd05 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 14 Feb 2023 14:16:22 +0000
Subject: [PATCH] Generate: input expansion for any model input (#21624)

---
 src/transformers/generation/tf_utils.py     | 81 ++++++++++-----------
 src/transformers/generation/utils.py        | 24 +++---
 tests/models/blip_2/test_modeling_blip_2.py | 20 ++---
 3 files changed, 58 insertions(+), 67 deletions(-)

diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index fe2d5ce5419..e2c5781f7eb 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -986,17 +986,13 @@ class TFGenerationMixin:
         )
 
         # 11. broadcast inputs to the desired number of beams
-        input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-        if "encoder_outputs" in model_kwargs:
-            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-            )
-
-        if "attention_mask" in model_kwargs:
-            model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-            )
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=generation_config.num_beams,
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            expand_in_new_axis=True,
+            **model_kwargs,
+        )
 
         # 12. run beam search
         return self.beam_search(
@@ -1025,17 +1021,13 @@ class TFGenerationMixin:
         logits_warper = self._get_logits_warper(generation_config=generation_config)
 
         # 12. broadcast inputs to the desired number of beams
-        input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-        if "encoder_outputs" in model_kwargs:
-            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-            )
-
-        if "attention_mask" in model_kwargs:
-            model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-            )
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=generation_config.num_beams,
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            expand_in_new_axis=True,
+            **model_kwargs,
+        )
 
         # 13. run beam sample (beam search with sampling)
         return self.beam_search(
@@ -1054,11 +1046,6 @@ class TFGenerationMixin:
             **model_kwargs,
         )
 
-    @staticmethod
-    def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        shape = shape_list(tensor)
-        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
-
     def _prepare_attention_mask_for_generation(
         self,
         inputs: tf.Tensor,
@@ -1142,29 +1129,37 @@ class TFGenerationMixin:
         expand_size: int = 1,
         is_encoder_decoder: bool = False,
         input_ids: Optional[tf.Tensor] = None,
+        expand_in_new_axis: bool = False,
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
-        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+        """
+        Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
+        depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
+        `expand_in_new_axis=True`
+        """
+
+        def _expand_tensor(tensor: tf.Tensor):
+            if expand_in_new_axis:
+                shape = shape_list(tensor)
+                return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))
+            else:
+                return tf.repeat(tensor, expand_size, axis=0)
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
+                    dict_to_expand[key] = _expand_tensor(dict_to_expand[key])
+            return dict_to_expand
+
         if input_ids is not None:
-            input_ids = tf.repeat(input_ids, expand_size, axis=0)
+            input_ids = _expand_tensor(input_ids)
 
-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = tf.repeat(model_kwargs["token_type_ids"], expand_size, axis=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = tf.repeat(model_kwargs["attention_mask"], expand_size, axis=0)
-
-        if model_kwargs.get("decoder_attention_mask") is not None:
-            model_kwargs["decoder_attention_mask"] = tf.repeat(
-                model_kwargs["decoder_attention_mask"], expand_size, axis=0
-            )
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
 
         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, expand_size, axis=0)
-            model_kwargs["encoder_outputs"] = encoder_outputs
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
         return input_ids, model_kwargs
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ce955baef1a..61f6090a9d6 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -671,26 +671,22 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
 
-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].repeat_interleave(expand_size, dim=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
 
         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
-                expand_size, dim=0
-            )
-            model_kwargs["encoder_outputs"] = encoder_outputs
-            decoder_attention_mask = model_kwargs.get("decoder_attention_mask")
-            if decoder_attention_mask is not None:
-                model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0)
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
         return input_ids, model_kwargs
 
diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index ef3bc184538..93d63104c53 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "it's not a city, it's a beach")
 
-    def test_inference_opt_batched(self):
+    def test_inference_opt_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
 
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
 
-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)
 
-        # Test output
-        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
-        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
+        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
 
     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "san diego")
 
-    def test_inference_t5_batched(self):
+    def test_inference_t5_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
 
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
 
-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)
 
-        # Test output
-        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
-        self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
+        self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
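
For readers skimming the diff, here is a minimal standalone sketch of the expansion semantics the patch generalizes. It is illustrative only, not part of the patch: the helper name `expand_inputs` and the toy shapes are made up. The point of the refactor is that every tensor-valued model input is expanded, rather than a hard-coded list (`attention_mask`, `token_type_ids`, ...), which is what lets an input such as BLIP-2's `pixel_values` survive `generate(..., num_beams=2)`:

```python
# Illustrative sketch (hypothetical helper, PyTorch-only) of the two expansion
# layouts used by the patched _expand_inputs_for_generation.
import torch

def expand_inputs(expand_size=1, input_ids=None, expand_in_new_axis=False, **model_kwargs):
    def _expand(tensor):
        if expand_in_new_axis:
            # TF beam methods: [batch, ...] -> [batch, expand_size, ...] via a new beam axis
            return tensor.unsqueeze(1).expand(tensor.shape[0], expand_size, *tensor.shape[1:])
        # Flat layout: [batch, ...] -> [batch * expand_size, ...], copies of a sample stay adjacent
        return tensor.repeat_interleave(expand_size, dim=0)

    if input_ids is not None:
        input_ids = _expand(input_ids)
    # Expand *any* tensor-valued kwarg; non-tensor entries pass through untouched
    model_kwargs = {k: _expand(v) if isinstance(v, torch.Tensor) else v for k, v in model_kwargs.items()}
    return input_ids, model_kwargs

input_ids = torch.tensor([[1, 2], [3, 4]])  # batch_size = 2
pixel_values = torch.randn(2, 3, 4, 4)      # an image input, as in the BLIP-2 tests above
ids, kwargs = expand_inputs(expand_size=3, input_ids=input_ids, pixel_values=pixel_values)
print(ids.shape, kwargs["pixel_values"].shape)  # torch.Size([6, 2]) torch.Size([6, 3, 4, 4])
```

Keeping the `expand_size` copies of each sample adjacent (`repeat_interleave` rather than tiling) is what the decoding loops rely on when regrouping beams per sample, while the TF-only `expand_in_new_axis=True` path materializes the beam axis directly, replacing the removed `_expand_to_num_beams` helper.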