Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00).
Commit: Add inputs_embeds support when generating with GPT-J (#21575)
This commit is contained in:
parent
dcb5e01197
commit
93ed89bf40
@ -771,7 +771,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
    """Replace the model's output projection (LM head) with `new_embeddings`."""
    self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for input_ids if past is defined in kwargs
|
# only last token for input_ids if past is defined in kwargs
|
||||||
if past_key_values:
|
if past_key_values:
|
||||||
@ -790,14 +790,24 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
position_ids = None
|
position_ids = None
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
"past_key_values": past_key_values,
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
"use_cache": kwargs.get("use_cache"),
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
"position_ids": position_ids,
|
else:
|
||||||
"attention_mask": attention_mask,
|
model_inputs = {"input_ids": input_ids}
|
||||||
"token_type_ids": token_type_ids,
|
|
||||||
}
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
Loading…
Reference in New Issue
Block a user