Mirror of https://github.com/huggingface/transformers.git
Add batch inferencing support for GPT2LMHeadModel (#7552)
* Add support for gpt2 batch inferencing
* add test
* remove typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
parent 0c64b18840
commit 121dd4332b
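The change teaches GPT2LMHeadModel.prepare_inputs_for_generation to accept an attention_mask and derive position_ids from it, which is what makes left-padded batch generation work. Before the diff, a minimal sketch of the usage this enables; prompts and settings mirror the new test in the second hunk, nothing here goes beyond the tested API:

    # Minimal usage sketch: left-padded batch generation with GPT-2.
    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # GPT-2 has no pad token; reuse EOS, and pad on the left so every
    # row ends with its real tokens right before the continuation.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    model.config.pad_token_id = model.config.eos_token_id

    sentences = ["Hello, my dog is a little", "Today, I"]
    inputs = tokenizer(sentences, return_tensors="pt", padding=True)

    # Passing attention_mask lets prepare_inputs_for_generation (first
    # hunk below) build position_ids that skip the pad positions.
    torch.manual_seed(0)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))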
@@ -701,10 +701,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
         return {
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
         }
 
     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
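To see what the new position_ids logic does on a left-padded batch, here is a small worked example; the tensor values are illustrative, not part of the commit:

    import torch

    # Left-padded batch: row 0 has no padding, row 1 has two pad tokens.
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [0, 0, 1, 1]])

    # Same arithmetic as the hunk above: cumulative sum minus one gives
    # 0-based positions for the real tokens; pad slots go negative...
    position_ids = attention_mask.long().cumsum(-1) - 1
    # ...and are clamped to a dummy value of 1 (they are masked out of
    # attention anyway, so the value never influences the output).
    position_ids.masked_fill_(attention_mask == 0, 1)

    print(position_ids)
    # tensor([[0, 1, 2, 3],
    #         [1, 1, 0, 1]])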
@@ -33,6 +33,7 @@ if is_torch_available():
         GPT2ForSequenceClassification,
         GPT2LMHeadModel,
         GPT2Model,
+        GPT2Tokenizer,
     )
 
 
@@ -425,6 +426,50 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
 
+    @slow
+    def test_batch_generation(self):
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+        model.to(torch_device)
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        tokenizer.padding_side = "left"
+
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = model.config.eos_token_id
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+
+        torch.manual_seed(0)
+        outputs = model.generate(
+            input_ids=inputs["input_ids"].to(torch_device),
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
+            "Today, I'm going to be doing a lot of research on this. I",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
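The num_paddings arithmetic in the test is easiest to follow with concrete lengths. The token counts below are assumptions stated for illustration (the test does not hard-code them), but they are consistent with the expected output strings above:

    # With the "gpt2" tokenizer, "Hello, my dog is a little" -> 7 tokens
    # and "Today, I" -> 3 tokens, so in the padded batch the short row
    # carries 7 - 3 = 4 pad tokens:
    num_paddings = 7 - 3  # inputs_non_padded.shape[-1] minus row 1's real length

    # Batched generation runs until model.config.max_length (20 by
    # default), and the 4 pads count toward that limit. The unbatched
    # run of the short sentence must therefore stop 4 tokens earlier
    # for both runs to generate the same number of new tokens:
    max_length = 20 - num_paddings  # 16 -> 13 generated tokens either way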