Generate: add Bloom fixes for contrastive search (#20213)
parent fda125638f · commit 938cb04789
@@ -672,8 +672,7 @@ class GenerationMixin:
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _extract_past_from_model_output(outputs: ModelOutput):
+    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
         past = None
         if "past_key_values" in outputs:
             past = outputs.past_key_values

@@ -681,13 +680,24 @@ class GenerationMixin:
             past = outputs.mems
         elif "past_buckets_states" in outputs:
             past = outputs.past_buckets_states
+
+        # Bloom fix: standardizes the cache format when requested
+        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
+            batch_size = outputs.logits.shape[0]
+            past = self._convert_to_standard_cache(past, batch_size=batch_size)
         return past
 
     def _update_model_kwargs_for_generation(
-        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past
-        model_kwargs["past"] = self._extract_past_from_model_output(outputs)
+        model_kwargs["past"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
 
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:

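The guarded branch above only rewrites the cache for models that expose `_convert_to_standard_cache`; every other model passes through unchanged. A minimal sketch of that dispatch, using made-up stand-in classes and dummy tensors rather than the real models:

import torch


class NoConverter:
    """Stands in for a model whose cache already uses the standard layout."""


class BloomLike:
    """Stands in for Bloom: batch and heads are fused in the first cache dimension."""

    @staticmethod
    def _convert_to_standard_cache(past, batch_size):
        key, value = past[0]
        fused, head_dim, seq_len = key.shape
        num_heads = fused // batch_size
        return (
            (
                key.view(batch_size, num_heads, head_dim, seq_len),
                value.view(batch_size, num_heads, seq_len, head_dim),
            ),
        )


def extract_past(model, past, batch_size, standardize_cache_format=False):
    # mirrors the guarded branch added in the diff above
    if standardize_cache_format and hasattr(model, "_convert_to_standard_cache"):
        past = model._convert_to_standard_cache(past, batch_size=batch_size)
    return past


bloom_past = ((torch.zeros(2 * 4, 8, 5), torch.zeros(2 * 4, 5, 8)),)  # batch=2, heads=4
print(extract_past(BloomLike(), bloom_past, 2, standardize_cache_format=True)[0][0].shape)    # [2, 4, 8, 5]
print(extract_past(NoConverter(), bloom_past, 2, standardize_cache_format=True)[0][0].shape)  # unchanged: [8, 8, 5]
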
@@ -1939,7 +1949,10 @@ class GenerationMixin:
                 logit_for_next_step = outputs.logits[:, -1, :]
 
                 model_kwargs = self._update_model_kwargs_for_generation(
-                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                    outputs,
+                    model_kwargs,
+                    is_encoder_decoder=self.config.is_encoder_decoder,
+                    standardize_cache_format=True,
                 )
 
             # Expands model inputs top_k times, for batched forward passes (akin to beam search).

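Requesting the standard format here matters because the very next step expands every tensor top_k times along the batch dimension; in Bloom's fused layout that dimension mixes batch and heads, so a batch-wise expansion no longer lines up with the expanded input_ids. A rough illustration with arbitrary shapes (not the library code):

import torch

batch_size, num_heads, head_dim, seq_len, top_k = 2, 4, 8, 5, 3

# standard layout: dim 0 is the batch, so repeating along it duplicates whole
# examples and keeps every head of an example together
standard_key = torch.randn(batch_size, num_heads, head_dim, seq_len)
print(standard_key.repeat_interleave(top_k, dim=0).shape)  # [6, 4, 8, 5]

# fused Bloom layout: dim 0 is batch * heads, so the same operation repeats
# individual head rows and the row order no longer matches the expanded inputs
bloom_key = standard_key.view(batch_size * num_heads, head_dim, seq_len)
print(bloom_key.repeat_interleave(top_k, dim=0).shape)  # [24, 8, 5], pairing lost
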
@@ -2001,7 +2014,7 @@ class GenerationMixin:
             outputs = self(
                 **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
             )
-            next_past_key_values = self._extract_past_from_model_output(outputs)
+            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
 
             logits = outputs.logits[:, -1, :]
             # name is different for encoder-decoder and decoder-only models

@@ -506,6 +506,45 @@ class BloomPreTrainedModel(PreTrainedModel):
         if isinstance(module, BloomModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_bloom_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 BLOOM_START_DOCSTRING = r"""

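A quick shape check of the two static helpers above, assuming a transformers version that includes this commit; the cache tensors are dummies and cover a single layer:

import torch
from transformers import BloomForCausalLM  # the helpers live on BloomPreTrainedModel, so any subclass works

batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 5
fused = batch_size * num_heads
bloom_cache = ((torch.zeros(fused, head_dim, seq_len), torch.zeros(fused, seq_len, head_dim)),)

standard = BloomForCausalLM._convert_to_standard_cache(bloom_cache, batch_size=batch_size)
print(standard[0][0].shape)  # torch.Size([2, 4, 8, 5])
print(standard[0][1].shape)  # torch.Size([2, 4, 5, 8])

round_trip = BloomForCausalLM._convert_to_bloom_cache(standard)
print(round_trip[0][0].shape)  # torch.Size([8, 8, 5]), back to the fused layout
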
@@ -811,6 +850,10 @@ class BloomForCausalLM(BloomPreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
+            # the cache may be in the standard format (e.g. in contrastive search), convert to Bloom's format if needed
+            if past[0][0].shape[0] == input_ids.shape[0]:
+                past = self._convert_to_bloom_cache(past)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past,

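The shape test above is how `prepare_inputs_for_generation` tells the two layouts apart: the fused Bloom cache has batch_size * num_heads rows in its first dimension, so it can only match input_ids.shape[0] when the cache arrived in the standard layout (this relies on num_heads being greater than 1). A small illustration with dummy tensors:

import torch

batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 5
input_ids = torch.zeros(batch_size, 1, dtype=torch.long)

bloom_key = torch.zeros(batch_size * num_heads, head_dim, seq_len)
standard_key = torch.zeros(batch_size, num_heads, head_dim, seq_len)

print(bloom_key.shape[0] == input_ids.shape[0])     # False: already in Bloom's format
print(standard_key.shape[0] == input_ids.shape[0])  # True: needs _convert_to_bloom_cache
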
@@ -896,9 +939,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )
 
-    @staticmethod
     def _reorder_cache(
-        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or

@@ -907,28 +949,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
 
         Output shares the same memory storage as `past`.
         """
-        batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
-        batch_size = len(beam_idx)
-        num_heads = batch_size_times_num_heads // batch_size
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
 
         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
             past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
         }
-        # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
-        # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
+        reordered_past = tuple(
             (
-                layer_past[0]
-                .view(batch_size, num_heads, head_dim, seq_length)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1]
-                .view(batch_size, num_heads, seq_length, head_dim)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, seq_length, head_dim),
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in past
+            for layer_past in standardized_past
        )
+        return self._convert_to_bloom_cache(reordered_past)
 
 
 @add_start_docstrings(

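With the helpers available, reordering boils down to: standardize, pick the surviving beams with index_select along the batch dimension, then fuse back. A rough sketch of that round trip on a single layer's key tensor, with dummy shapes (not the library code):

import torch

batch_size, num_heads, head_dim, seq_len = 3, 2, 4, 5
beam_idx = torch.tensor([2, 0, 0])  # beam 2 and two copies of beam 0 survive

key = torch.arange(batch_size * num_heads * head_dim * seq_len, dtype=torch.float32)
key = key.view(batch_size * num_heads, head_dim, seq_len)         # fused Bloom layout

standard = key.view(batch_size, num_heads, head_dim, seq_len)     # standardize
reordered = standard.index_select(0, beam_idx)                    # reorder whole examples
back = reordered.view(batch_size * num_heads, head_dim, seq_len)  # fuse back

# the first fused row now holds head 0 of what used to be example 2
print(torch.equal(back[0], key[2 * num_heads]))  # True
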
@@ -1411,9 +1411,8 @@ class GenerationTesterMixin:
         # check `generate()` and `contrastive_search()` are equal
         for model_class in self.all_generative_model_classes:
 
-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
 
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

@@ -1434,9 +1433,8 @@ class GenerationTesterMixin:
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
 
-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
 
             # enable cache
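With these changes, contrastive search runs on Bloom checkpoints end to end. A usage sketch; the checkpoint name is only an example, and penalty_alpha together with top_k is what switches generate() into contrastive search:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))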