Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 23:08:57 +06:00)
feat: Sequential beam search (#26304)
commit d4fc1eb498, parent 268fc1fdfa
@@ -200,7 +200,8 @@ class GenerationConfig(PushToHubMixin):
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
         low_memory (`bool`, *optional*):
-            Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
+            Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
+            Used with beam search and contrastive search.
 
 
     > Parameters that define the output variables of `generate`
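For orientation, here is a minimal usage sketch of the extended flag. It is only an illustration: the checkpoint, prompt, and generation settings are arbitrary (the `gpt2` model mirrors the integration test added at the end of this diff), and no particular output is asserted.

# Minimal sketch: low_memory now covers beam search as well as contrastive search.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("I", return_tensors="pt")

# With low_memory=True, beam search runs sequentially over sub-batches, trading
# speed for lower peak memory; the generated tokens should match low_memory=False.
out = model.generate(**inputs, max_new_tokens=20, num_beams=4, low_memory=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))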
@@ -1558,6 +1558,7 @@ class GenerationMixin:
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
 
@@ -1951,8 +1952,7 @@ class GenerationMixin:
                     model_kwargs["past_key_values"] = tuple(new_key_values)
 
             if sequential:
-                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
-                all_last_hstates, all_hstates, all_logits = [], [], []
+                all_outputs = []
                 for i in range(top_k):
                     # compute the candidate tokens by the language model and collect their hidden_states
                     next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
@@ -1963,32 +1963,8 @@ class GenerationMixin:
                         output_hidden_states=True,
                         output_attentions=output_attentions,
                     )
-                    for key in all_outputs:
-                        all_outputs[key].append(outputs[key])
+                    all_outputs.append(outputs)
+                outputs = stack_model_outputs(all_outputs)
 
-                    if self.config.is_encoder_decoder:
-                        next_hidden = outputs.decoder_hidden_states[-1]
-                        full_hidden_states = outputs.decoder_hidden_states
-
-                    else:
-                        next_hidden = outputs.hidden_states[-1]
-                        full_hidden_states = outputs.hidden_states
-
-                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
-                    all_hstates.append(full_hidden_states)
-                    all_logits.append(outputs.logits[:, -1, :])
-
-                # stack hidden states
-                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
-                final_full_hstates = [0 for i in range(len(full_hidden_states))]
-                for layer in range(len(full_hidden_states)):
-                    final_full_hstates[layer] = torch.stack(
-                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
-                    )
-                full_hidden_states = tuple(final_full_hstates)
-
-                # stack logits
-                logits = torch.cat(all_logits, dim=0)
-
             else:
                 # compute the candidate tokens by the language model and collect their hidden_states
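The refactor above replaces per-field bookkeeping (separate lists of last hidden states, per-layer hidden states, and logits that were stacked by hand) with accumulating whole model outputs and merging them once via `stack_model_outputs`, which this diff adds further down. A toy sketch of that accumulate-then-concatenate pattern, with made-up shapes and a stand-in for the per-candidate forward pass:

# Illustration only: stand-in tensors, not the transformers implementation.
import torch

top_k, vocab_size = 4, 50257
per_candidate_logits = []
for _ in range(top_k):
    # one "forward pass" over a single candidate token, batch size 1
    per_candidate_logits.append(torch.randn(1, vocab_size))

# concatenating along the batch dim is what stack_model_outputs does
# field by field on full ModelOutput objects
logits = torch.cat(per_candidate_logits, dim=0)  # shape: [top_k, vocab_size]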
@@ -2001,15 +1977,15 @@ class GenerationMixin:
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
                 # name is different for encoder-decoder and decoder-only models
                 if self.config.is_encoder_decoder:
                     next_hidden = outputs.decoder_hidden_states[-1]
                     full_hidden_states = outputs.decoder_hidden_states
                 else:
                     next_hidden = outputs.hidden_states[-1]
                     full_hidden_states = outputs.hidden_states
 
                 logits = outputs.logits[:, -1, :]
 
             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
 
@@ -2747,6 +2723,7 @@ class GenerationMixin:
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2792,6 +2769,10 @@ class GenerationMixin:
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            sequential (`bool`, defaults to `False`):
+                By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for
+                more details). This flag will avoid parallelizing the beam search and will instead run beam search
+                sequentially.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
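A quick back-of-the-envelope reading of the new docstring, with illustrative numbers only (the variable names below are just for the example):

batch_size, num_beams = 2, 5

# default beam search: one forward pass over the flattened beam batch
effective_batch = batch_size * num_beams            # 10 rows per forward pass

# sequential=True: several smaller passes of batch_size rows each
sequential_passes = effective_batch // batch_size   # 5 passes of 2 rows
print(effective_batch, sequential_passes)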
@@ -2858,6 +2839,7 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if max_length is not None:
             warnings.warn(
                 "`max_length` is deprecated in this function, use"
@@ -2932,12 +2914,39 @@ class GenerationMixin:
 
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # if sequential is True, split the input to batches of batch_size and run sequentially
+            if sequential:
+                if any(
+                    model_name in self.__class__.__name__.lower()
+                    for model_name in ["fsmt", "reformer", "bloom", "ctrl", "gpt_bigcode", "transo_xl", "xlnet", "cpm"]
+                ):
+                    raise RuntimeError(
+                        f"Currently generation for {self.__class__.__name__} is not supported "
+                        f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
+                    )
+
+                inputs_per_sub_batches = _split_model_inputs(
+                    model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
+                )
+                outputs_per_sub_batch = [
+                    self(
+                        **inputs_per_sub_batch,
+                        return_dict=True,
+                        output_attentions=output_attentions,
+                        output_hidden_states=output_hidden_states,
+                    )
+                    for inputs_per_sub_batch in inputs_per_sub_batches
+                ]
+
+                outputs = stack_model_outputs(outputs_per_sub_batch)
+
+            else:  # Unchanged original behavior
+                outputs = self(
+                    **model_inputs,
+                    return_dict=True,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                )
 
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
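The block above is the heart of the feature: split the flattened `batch_size * num_beams` input into `num_beams` sub-batches, run the model on each, and re-stack the results so the beam scorer sees the usual layout. A self-contained sketch of that pattern with plain tensors and a stand-in forward function (hypothetical shapes; the real code operates on full model-input dicts via `_split_model_inputs` and merges `ModelOutput` objects with `stack_model_outputs`):

import torch

def fake_forward(input_ids):  # stand-in for self(**model_inputs)
    return torch.randn(input_ids.shape[0], 10)

batch_size, num_beams = 2, 5
batch_beam_size = batch_size * num_beams
input_ids = torch.randint(0, 100, (batch_beam_size, 7))

# split the flattened beam batch into num_beams sub-batches of batch_size rows
sub_batches = [input_ids[i : i + batch_size] for i in range(0, batch_beam_size, batch_size)]
outputs_per_sub_batch = [fake_forward(x) for x in sub_batches]

# re-stack along the batch dimension so downstream beam scoring is unchanged
logits = torch.cat(outputs_per_sub_batch, dim=0)
assert logits.shape[0] == batch_beam_size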
@@ -4656,3 +4665,139 @@ def _ranking_fast(
     contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
+
+
+def _split(data, full_batch_size: int, split_size: int = None):
+    """
+    Takes care of three cases:
+    1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
+    2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
+       return a list of tuples
+    3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
+       return a list of tuples of tuples
+    (see documentation of ModelOutput)
+    """
+    if data is None:
+        return [None] * (full_batch_size // split_size)
+    if isinstance(data, torch.Tensor):
+        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
+    elif isinstance(data, tuple):
+        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+        if isinstance(data[0], tuple):
+            return [
+                tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+
+        else:
+            return [
+                tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+    else:
+        raise ValueError(f"Unexpected attribute type: {type(data)}")
+
+
+def _split_model_inputs(
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+) -> List[Union[ModelOutput, Dict]]:
+    """
+    Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
+    size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
+    previous forward pass.
+    """
+    # Edge case: if model_input is None, return a list of Nones
+    # this happens with Whisper where encoder_outputs is None
+    if model_input is None:
+        return [model_input] * (full_batch_size // split_size)
+    # Infer the class from the object
+    model_output_cls = type(model_input)
+    if (full_batch_size % split_size) != 0:
+        raise ValueError("`full_batch_size` must be divisible by `split_size`")
+
+    if split_size > full_batch_size:
+        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
+
+    # Helper function to split tensors or tuples of tensors
+
+    # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
+    keys = (
+        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
+    )
+    # We only keep keys that are in the model_input
+    keys = [k for k in keys if k in model_input]
+    # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
+    # ModelOutput object.
+    # bool should not be split but replicated for each split
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+
+    # we split the tensors and tuples of tensors
+    data_split_list = [
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        for i in range(full_batch_size // split_size)
+    ]
+    # bool values are the same and replicated for each split
+    bool_data = {k: model_input[k] for k in bool_keys}
+    # encoder_outputs is a ModelOutput object and should be split by its own
+    if "encoder_outputs" in model_input:
+        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        data_split_list = [
+            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
+        ]
+
+    # Convert each dictionary in the list to an object of the inferred class
+    split_model_inputs: List[Union[ModelOutput, Dict]] = [
+        model_output_cls(**data_split, **bool_data) for data_split in data_split_list
+    ]
+
+    return split_model_inputs
+
+
+def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
+    """
+    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
+    specific ModelOutput subclass from the list provided.
+    """
+    if not model_outputs:
+        raise ValueError("Input list is empty.")
+
+    # Infer the class from the first object in the list
+    model_output_cls = type(model_outputs[0])
+
+    # Ensure all objects are of the same type
+    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
+        raise ValueError("All elements in the list should be of the same type.")
+
+    # Helper function to concat tensors or tuples of tensors
+    def _concat(data):
+        """
+        Reverse of `_split` function above.
+        """
+        if any(data is None for data in data):
+            return None
+        if isinstance(data[0], torch.Tensor):
+            return torch.cat(data, dim=0)
+        elif isinstance(data[0], tuple):
+            # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+            if isinstance(data[0][0], tuple):
+                return tuple(
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
+                    for i in range(len(data[0]))
+                )
+            else:
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
+        elif isinstance(data[0], (int, float)):
+            # If the elements are integers or floats, return a tensor
+            return torch.tensor(data)
+        else:
+            raise ValueError(f"Unexpected attribute type: {type(data[0])}")
+
+    # Use a dictionary comprehension to gather attributes from all objects and concatenate them
+    concatenated_data = {
+        k: _concat([getattr(model_output, k) for model_output in model_outputs])
+        for k in model_output_cls.__dataclass_fields__.keys()
+    }
+
+    # Return a new object of the inferred class with the concatenated attributes
+    return model_output_cls(**concatenated_data)
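A round-trip sketch for the two new helpers, assuming a transformers install that already contains this commit, where both are importable from `transformers.generation.utils` with the signatures shown above (later releases may change them). The shapes and the choice of `CausalLMOutputWithPast` are illustrative only.

import torch
from transformers.generation.utils import _split_model_inputs, stack_model_outputs
from transformers.modeling_outputs import CausalLMOutputWithPast

# a fake flattened beam batch: batch_beam_size = 10
full = CausalLMOutputWithPast(logits=torch.randn(10, 7, 50257))

# split into 5 sub-batches of 2 rows, as beam search does with split_size=batch_size
parts = _split_model_inputs(full, split_size=2, full_batch_size=10)
assert len(parts) == 5 and parts[0].logits.shape[0] == 2

# merging the pieces recovers the original batch dimension
merged = stack_model_outputs(parts)
assert torch.equal(merged.logits, full.logits)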
@@ -1539,6 +1539,39 @@ class GenerationTesterMixin:
         )
         self.assertListEqual(low_output.tolist(), high_output.tolist())
 
+    def test_beam_search_low_memory(self):
+        # Check that choosing 'low_memory' does not change the model output
+        for model_class in self.all_generative_model_classes:
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in [
+                    "bloom",
+                    "ctrl",
+                    "gptbigcode",
+                    "transo_xl",
+                    "xlnet",
+                    "cpm",
+                ]
+            ):
+                self.skipTest("May fix in the future: need model-specific fixes")
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
+            # batch_size=1 is ok, but batch_size>1 will cause non-identical output
+
+            config.use_cache = True
+            config.is_decoder = True
+
+            # test output equality of low versus high memory
+            model = model_class(config).to(torch_device).eval()
+
+            low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
+
+            high_output = model.generate(
+                input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
+            )
+            self.assertListEqual(low_output.tolist(), high_output.tolist())
+
     @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -2766,6 +2799,19 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
 
+    def test_beam_search_low_memory(self):
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
+
+        low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
+
+        high_output = model.generate(
+            model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
+        )
+        self.assertListEqual(low_output.tolist(), high_output.tolist())
+
     @slow
     def test_beam_search_example_integration(self):
         # PT-only test: TF doesn't have a BeamSearchScorer