Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-29 17:22:25 +06:00)

Generate: input expansion for any model input (#21624)

This commit replaces the name-by-name input expansion in `generate()` (`input_ids`, `attention_mask`, `token_type_ids`, `decoder_attention_mask`, `encoder_outputs`) with a generic loop that expands every tensor in `model_kwargs`, so beam search works with any model input. Both the TF and PyTorch generation mixins are updated, and the Blip2 integration tests are switched to batched beam search to exercise the new path.

This commit is contained in:
parent 13e03e619d
commit a81fe4e1df
In `TFGenerationMixin`, the beam search entry point stops expanding `input_ids`, `encoder_outputs`, and `attention_mask` one by one and instead delegates to the generic `_expand_inputs_for_generation` helper:

```diff
@@ -986,17 +986,13 @@ class TFGenerationMixin:
             )
 
             # 11. broadcast inputs to the desired number of beams
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-            if "encoder_outputs" in model_kwargs:
-                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-                )
-
-            if "attention_mask" in model_kwargs:
-                model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-                )
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                expand_in_new_axis=True,
+                **model_kwargs,
+            )
 
             # 12. run beam search
             return self.beam_search(
```
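For orientation (not part of the diff): with `expand_in_new_axis=True`, every `[batch_size, ...]` tensor gains a beam axis rather than being flattened, which is the layout the TF beam implementation indexes into. A minimal standalone sketch of the shape transformation, with assumed toy dimensions:

```python
import tensorflow as tf

batch_size, num_beams, seq_len = 2, 3, 4
input_ids = tf.ones((batch_size, seq_len), dtype=tf.int32)

# tensor[:, None] inserts the beam axis; broadcast_to materializes it.
expanded = tf.broadcast_to(input_ids[:, None], (batch_size, num_beams, seq_len))
print(expanded.shape)  # (2, 3, 4) -> [batch_size, num_beams, seq_len]
```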
The beam sample path receives the identical treatment:

```diff
@@ -1025,17 +1021,13 @@ class TFGenerationMixin:
             logits_warper = self._get_logits_warper(generation_config=generation_config)
 
             # 12. broadcast inputs to the desired number of beams
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-            if "encoder_outputs" in model_kwargs:
-                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-                )
-
-            if "attention_mask" in model_kwargs:
-                model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-                )
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                expand_in_new_axis=True,
+                **model_kwargs,
+            )
 
             # 13. run beam sample (beam search with sampling)
             return self.beam_search(
```
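The two layouts the helper can produce differ only by a reshape, so code that needs the flattened `[batch_size * expand_size, ...]` view can recover it cheaply. A standalone sketch (toy shapes assumed):

```python
import tensorflow as tf

batch_size, num_beams, seq_len = 2, 3, 4
beamed = tf.zeros((batch_size, num_beams, seq_len))  # expand_in_new_axis=True layout

# Flattening keeps the beams of each batch item adjacent, matching the
# row order tf.repeat(..., axis=0) would have produced directly.
flat = tf.reshape(beamed, (batch_size * num_beams, seq_len))
print(flat.shape)  # (6, 4)
```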
With both call sites gone, the `_expand_to_num_beams` static helper is removed from `TFGenerationMixin`:

```diff
@@ -1054,11 +1046,6 @@ class TFGenerationMixin:
                 **model_kwargs,
             )
 
-    @staticmethod
-    def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        shape = shape_list(tensor)
-        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
-
     def _prepare_attention_mask_for_generation(
         self,
         inputs: tf.Tensor,
```
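The deleted helper is not lost: its body survives verbatim as the `expand_in_new_axis=True` branch of `_expand_tensor` in the next hunk. A standalone reproduction (using `tensor.shape.as_list()` in place of the library's `shape_list` utility):

```python
import tensorflow as tf

def expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
    # Same computation as the removed static method.
    shape = tensor.shape.as_list()
    return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))

x = tf.reshape(tf.range(6), (2, 3))
print(expand_to_num_beams(x, num_beams=4).shape)  # (2, 4, 3)
```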
`_expand_inputs_for_generation` itself is generalized: a new `expand_in_new_axis` flag selects between the flattened `[batch_size * expand_size, ...]` layout and the beam-friendly `[batch_size, expand_size, ...]` layout, and a loop over `model_kwargs` replaces the hand-written per-key expansion:

```diff
@@ -1142,29 +1129,37 @@ class TFGenerationMixin:
         expand_size: int = 1,
         is_encoder_decoder: bool = False,
         input_ids: Optional[tf.Tensor] = None,
+        expand_in_new_axis: bool = False,
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
-        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+        """
+        Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
+        depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
+        `expand_in_new_axis=True`
+        """
+
+        def _expand_tensor(tensor: tf.Tensor):
+            if expand_in_new_axis:
+                shape = shape_list(tensor)
+                return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))
+            else:
+                return tf.repeat(tensor, expand_size, axis=0)
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
+                    dict_to_expand[key] = _expand_tensor(dict_to_expand[key])
+            return dict_to_expand
+
         if input_ids is not None:
-            input_ids = tf.repeat(input_ids, expand_size, axis=0)
+            input_ids = _expand_tensor(input_ids)
 
-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = tf.repeat(model_kwargs["token_type_ids"], expand_size, axis=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = tf.repeat(model_kwargs["attention_mask"], expand_size, axis=0)
-
-        if model_kwargs.get("decoder_attention_mask") is not None:
-            model_kwargs["decoder_attention_mask"] = tf.repeat(
-                model_kwargs["decoder_attention_mask"], expand_size, axis=0
-            )
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
 
         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, expand_size, axis=0)
-            model_kwargs["encoder_outputs"] = encoder_outputs
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
         return input_ids, model_kwargs
```
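The point of the rewrite: every `tf.Tensor` in `model_kwargs` is now expanded, so model-specific inputs no longer need to be special-cased by name. A standalone sketch of the loop's effect (`pixel_values` here is a hypothetical extra input, not taken from the diff):

```python
import tensorflow as tf

def expand_dict_for_generation(dict_to_expand, expand_size):
    # Mirrors the new inner helper, in its expand_in_new_axis=False mode.
    for key in dict_to_expand:
        if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
            dict_to_expand[key] = tf.repeat(dict_to_expand[key], expand_size, axis=0)
    return dict_to_expand

model_kwargs = {
    "attention_mask": tf.ones((3, 5), dtype=tf.int32),
    "pixel_values": tf.zeros((3, 224, 224, 3)),  # hypothetical extra input
    "output_attentions": False,                  # non-tensor: left untouched
}
model_kwargs = expand_dict_for_generation(model_kwargs, expand_size=2)
print(model_kwargs["attention_mask"].shape)  # (6, 5)
print(model_kwargs["pixel_values"].shape)    # (6, 224, 224, 3)
print(model_kwargs["output_attentions"])     # False
```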
The PyTorch `GenerationMixin` receives the same generalization. The flattened `repeat_interleave` layout is kept (PyTorch has no `expand_in_new_axis` option), and the `token_type_ids`, `attention_mask`, and `decoder_attention_mask` special cases are absorbed into the generic dict expansion:

```diff
@@ -671,26 +671,22 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
 
-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].repeat_interleave(expand_size, dim=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
 
         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
-                expand_size, dim=0
-            )
-            model_kwargs["encoder_outputs"] = encoder_outputs
-            decoder_attention_mask = model_kwargs.get("decoder_attention_mask")
-            if decoder_attention_mask is not None:
-                model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0)
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
 
         return input_ids, model_kwargs
```
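`repeat_interleave` keeps the copies of each batch item adjacent, which is what downstream beam scoring relies on when regrouping candidates per input. A standalone sketch:

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
print(x.repeat_interleave(2, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])  <- copies stay grouped, unlike torch.cat([x, x])
```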
In the Blip2 integration tests, the batched inference tests are renamed and switched to beam search (`num_beams=2`), exercising the generalized expansion on a model whose inputs include `pixel_values`:

```diff
@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "it's not a city, it's a beach")
 
-    def test_inference_opt_batched(self):
+    def test_inference_opt_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
 
```
```diff
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
 
-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)
 
-        # Test output
-        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
-        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
+        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
 
     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
```
```diff
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "san diego")
 
-    def test_inference_t5_batched(self):
+    def test_inference_t5_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
 
```
```diff
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
 
-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)
 
-        # Test output
-        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
-        self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
+        self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
```
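These tests pin down the behavior the refactor enables: Blip2 feeds `generate()` a `pixel_values` tensor rather than text `input_ids`, and the generic expansion now replicates it for beam search automatically. A hedged usage sketch along the lines of the updated tests (a blank image stands in for the tests' `prepare_img()`):

```python
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = processor(images=[image, image], return_tensors="pt")

# num_beams=2 routes through _expand_inputs_for_generation, which now
# also repeats pixel_values along the batch axis.
predictions = model.generate(**inputs, num_beams=2)
print(processor.batch_decode(predictions, skip_special_tokens=True))
```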