Allow bos_token_id is None during the generation with inputs_embeds (#29772)

* update

* add ut

* update
This commit is contained in:
Zhihao Lin 2024-03-26 20:51:00 +08:00 committed by GitHub
parent b9ceb03df8
commit 998b5bb56f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@ -436,9 +436,6 @@ class GenerationMixin:
shape = encoder_outputs.last_hidden_state.size()[:-1]
return torch.ones(shape, dtype=torch.long, device=self.device) * -100
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size = 1
@ -449,6 +446,10 @@ class GenerationMixin:
if "inputs_embeds" in model_kwargs:
return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(

View File

@ -1467,6 +1467,17 @@ class GenerationTesterMixin:
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)
model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None)
with self.assertRaises(ValueError):
model.generate(max_length=20, bos_token_id=None)
def test_generate_from_inputs_embeds_decoder_only(self):
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function