Add batch inferencing support for GPT2LMHeadModel (#7552)
* Add support for gpt2 batch inferencing
* add test
* remove typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent 0c64b18840
commit 121dd4332b
@@ -701,10 +701,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
         return {
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
         }
 
     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
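To make the effect of the new logic concrete, here is a minimal standalone sketch of the position_ids computation on a left-padded batch; the mask values are invented for illustration:

import torch

# Hypothetical left-padded batch of two prompts: the first fills all four
# slots, the second has two pad positions (attention_mask == 0) on the left.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])

# The same two lines the diff adds to prepare_inputs_for_generation.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])

Real tokens receive positions counted from the first non-pad token, so a left-padded row sees the same position sequence it would get without padding; pad slots are filled with a dummy value of 1, which is harmless because the attention mask excludes them anyway.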
@@ -33,6 +33,7 @@ if is_torch_available():
         GPT2ForSequenceClassification,
         GPT2LMHeadModel,
         GPT2Model,
+        GPT2Tokenizer,
     )
 
 
@@ -425,6 +426,50 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
 
+    @slow
+    def test_batch_generation(self):
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+        model.to(torch_device)
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        tokenizer.padding_side = "left"
+
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = model.config.eos_token_id
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+
+        torch.manual_seed(0)
+        outputs = model.generate(
+            input_ids=inputs["input_ids"].to(torch_device),
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
+            "Today, I'm going to be doing a lot of research on this. I",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
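For reference, the test above distills into a short end-to-end batch-inference recipe. This is a sketch reusing the test's example sentences, not an official snippet from the repository:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2 has no pad token, so reuse EOS for padding, and pad on the left
# so generation continues directly from the last real prompt token.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

sentences = ["Hello, my dog is a little", "Today, I"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)

outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))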