Mirror of https://github.com/huggingface/transformers.git
Fix contrastive search to correctly handle input with padding (#33507)
* fix: handle padding in contrastive search for decoder-only models
* fix: handle padding in contrastive search for encoder-decoder models
* tests: move padding contrastive test to test_util, add t5 test
* fix: handle if model_kwargs["decoder_attention_mask"] is None
* refactor: improve padding input contrastive search generation tests
* chore: _ranking_fast to use LongTensor for cosine_matrix_mask
This commit is contained in:
parent c0c6815dc9
commit dc8b6eaeee
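To orient the reader before the diff: the core idea of the fix is that padded positions must not contribute to the degeneration penalty, so their cosine similarities are pushed to a very large negative value before the per-candidate maximum is taken. A minimal, self-contained sketch of that idea (illustrative only; the helper name and shapes below are assumptions, not the transformers API):

import torch

def masked_degeneration_penalty(context_hidden, next_hidden, mask):
    # context_hidden: [B*K, S, D], next_hidden: [B*K, 1, D], mask: [B*K, S] with 1 = real token, 0 = padding
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    # Padded positions get a very large negative value so the row-wise max ignores them
    cosine = cosine + (1 - mask.to(cosine.dtype)) * torch.finfo(cosine.dtype).min
    return cosine.max(dim=-1).values  # [B*K]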
@@ -2604,6 +2604,15 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
+        # Create cosine_matrix_mask based on the attention_mask
+        cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
+        if self.config.is_encoder_decoder:
+            if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
+                cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
+        else:
+            cosine_matrix_mask = model_kwargs["attention_mask"]
+        cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)
+
        this_peer_finished = False
 
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
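As a quick shape check of the repeat_interleave step above (toy values, not taken from the commit): each sequence's mask is duplicated top_k times so it lines up with the top_k candidate continuations scored per sequence.

import torch

mask = torch.tensor([[0, 1, 1],
                     [1, 1, 1]])             # [batch_size=2, seq_len=3]; first row is left-padded
expanded = mask.repeat_interleave(4, dim=0)  # top_k = 4 -> [batch_size * top_k, seq_len]
print(expanded.shape)                        # torch.Size([8, 3])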
@@ -2771,7 +2780,12 @@
             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
             # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
             # introduce (noticeable) slowdowns on single-device runs.
-            selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
+            selected_idx = _ranking_fast(
+                context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k
+            )
+            cosine_matrix_mask = torch.cat(
+                [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1
+            )
             selected_idx = selected_idx.to("cpu")
 
             # This will be used instead of the previous inneficient torch.stack(torch.split())
@@ -4283,6 +4297,7 @@ def _ranking_fast(
     context_hidden: torch.FloatTensor,
     next_hidden: torch.FloatTensor,
     next_top_k_probs: torch.FloatTensor,
+    cosine_matrix_mask: torch.LongTensor,
     alpha: float,
     beam_width: int,
 ) -> torch.FloatTensor:
@@ -4294,6 +4309,13 @@ def _ranking_fast(
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
     cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
+
+    # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
+    # Using a large negative value for masked positions
+    cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
+    cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min
+    cosine_matrix = cosine_matrix + cosine_matrix_mask
+
     degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
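For intuition, a toy application of the contrastive score computed above (numbers invented for illustration): with penalty_alpha = 0.6, a slightly less probable candidate wins once it is much less similar to the non-padded context.

import torch

alpha = 0.6
probs = torch.tensor([0.50, 0.45])    # model confidence for two candidate tokens
penalty = torch.tensor([0.95, 0.30])  # max cosine similarity to the (non-padded) context
score = (1.0 - alpha) * probs - alpha * penalty
print(score)           # tensor([-0.3700,  0.0000])
print(score.argmax())  # tensor(1) -> the less repetitive candidate is selected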
@@ -44,6 +44,7 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
 
 if is_torch_available():
     import torch
+    import torch.nn.functional as F
 
     from transformers import (
         AutoModelForCausalLM,
@@ -59,6 +60,7 @@ if is_torch_available():
         GPT2Tokenizer,
         ImageGPTForCausalImageModeling,
         SpeechEncoderDecoderModel,
+        T5ForConditionalGeneration,
     )
     from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
     from transformers.generation import (
@@ -3644,6 +3646,139 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         value_cache_1 = results.past_key_values.value_cache[1]
         self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
+
+    @slow
+    def test_padding_input_contrastive_search_gpt2(self):
+        # Load the pre-trained GPT-2 model and tokenizer
+        model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
+        model.to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
+
+        # Set the tokenizer to left-pad the sequences
+        tokenizer.padding_side = "left"
+
+        # Define the PAD token as the EOS token
+        tokenizer.pad_token = tokenizer.eos_token
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
+
+        # Define the input prompt
+        prompt_text = "The whispered legends of the haunted mansion spoke"
+
+        # Tokenize the input prompt
+        encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
+        input_ids = encoded_prompt.input_ids.to(torch_device)
+        attention_mask = encoded_prompt.attention_mask.to(torch_device)
+
+        # Define the contrastive search params
+        penalty_alpha = 0.6
+        top_k = 4
+
+        # Define the padding length to add to the input IDs and attention mask
+        padding_length = 10
+
+        # Generate text without padding
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            do_sample=False,
+            penalty_alpha=penalty_alpha,
+            top_k=top_k,
+            max_new_tokens=64,
+        )
+        generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Pad the input IDs and attention mask on the left
+        padded_input_ids = F.pad(
+            input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
+        )
+        padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)
+
+        # Generate text with padded inputs
+        outputs_with_padding = model.generate(
+            input_ids=padded_input_ids,
+            attention_mask=padded_attention_mask,
+            do_sample=False,
+            penalty_alpha=penalty_alpha,
+            top_k=top_k,
+            max_new_tokens=64,
+        )
+        generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
+
+        # Assert that the generated texts are identical for padded and non-padded inputs
+        self.assertEqual(generated_text_no_padding, generated_text_with_padding)
+        self.assertEqual(
+            generated_text_with_padding,
+            'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
+            'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
+            'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
+        )
+
+    @slow
+    def test_padding_input_contrastive_search_t5(self):
+        # Load the pre-trained T5 model and tokenizer
+        model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
+        model.to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)
+
+        # Define the input prompt
+        prompt_text = "translate English to German: I need to finish this task before the end of the day."
+
+        # Tokenize the input prompt
+        encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
+        input_ids = encoded_prompt.input_ids.to(torch_device)
+        attention_mask = encoded_prompt.attention_mask.to(torch_device)
+
+        # Define the decoder prompt
+        decoder_prompt_text = "Ich muss diese Aufgabe"
+        encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
+        decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
+        decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)
+
+        # Define the contrastive search params
+        penalty_alpha = 0.6
+        top_k = 4
+
+        # Generate text without padding
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            do_sample=False,
+            penalty_alpha=penalty_alpha,
+            top_k=top_k,
+            max_new_tokens=64,
+        )
+        generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Define the padding length to add to the input IDs and attention mask
+        padding_length = 10
+
+        # Pad the decoder input IDs and attention mask on the left
+        padded_decoder_input_ids = F.pad(
+            decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
+        )
+        padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
+        # Since the decoder_start_token_id is the same as the pad_token_id,
+        # the last padded token represents the decoder start token.
+        # Set the attention mask for the decoder_start_token_id to True (1).
+        padded_decoder_attention_mask[:, padding_length - 1] = 1
+        # Generate text with padded inputs
+        outputs_with_padding = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=padded_decoder_input_ids,
+            decoder_attention_mask=padded_decoder_attention_mask,
+            do_sample=False,
+            penalty_alpha=penalty_alpha,
+            top_k=top_k,
+            max_new_tokens=64,
+        )
+        generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
+
+        # Assert that the generated texts are identical for padded and non-padded inputs
+        self.assertEqual(generated_text_no_padding, generated_text_with_padding)
+        self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
 
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):