Mirror of https://github.com/huggingface/transformers.git
* Make sure position ids are masked
* Test that padded inputs produce the same results
* Fix failing tests
* Fixup
* Fix batch test
Commit a3fef89b26 (parent eee195b3aa)
@@ -553,7 +553,14 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
             past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
-        if position_ids is None:
+
+        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_length > 0:
+                position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
+        elif position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 
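The key idea in this hunk is to derive position_ids from the 2D attention mask so that left padding does not shift the positions of the real tokens. A minimal standalone sketch of the same cumsum/masked_fill trick, with made-up example values that are not taken from the diff:

import torch

# Left-padded batch: the first sequence has two pad tokens on the left.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# A running count of real tokens gives each token its position. Pad slots
# would end up at -1, so they are overwritten with a dummy value of 1; they
# are masked out of attention and never influence the real tokens.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])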
@@ -797,7 +797,14 @@ class GPT2Model(GPT2PreTrainedModel):
             past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
-        if position_ids is None:
+
+        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_length > 0:
+                position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
+        elif position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 
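The same change is mirrored in GPT2Model. The past_length slice matters for incremental decoding with a key/value cache: position ids are computed over the whole attention mask (cached plus new tokens) and then trimmed to the tokens actually fed in this step. A rough sketch with hypothetical shapes, not code from the diff:

import torch

# Assume 4 tokens are already cached and 1 new token is passed in.
past_length = 4
input_shape = (1, 1)  # (batch, new tokens)

# The attention mask covers the cached tokens plus the new one (5 positions).
attention_mask = torch.tensor([[1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# Keep only the ids belonging to the tokens fed in this step.
if past_length > 0:
    position_ids = position_ids[:, past_length : input_shape[-1] + past_length]

print(position_ids)  # tensor([[4]])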
@@ -590,6 +590,27 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
         self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
 
+    @slow
+    def test_batch_forward(self):
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        tokenizer.padding_side = "left"
+
+        # This tokenizer has no pad token, so we have to set it in some way
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+
+        model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
+        sentences = ["Hello, my dog is a little bit of a mess. I'm not sure if he's"]
+        inputs = tokenizer(sentences, padding=True, return_tensors="pt")
+        logits = model(**inputs).logits[:, -1, :]
+        indexes = torch.argmax(logits).item()
+
+        inputs_padded = tokenizer(sentences, padding="max_length", max_length=30, return_tensors="pt")
+        logits_padded = model(**inputs_padded).logits[:, -1, :]
+        indexes_padded = torch.argmax(logits_padded).item()
+
+        self.assertTrue(indexes == indexes_padded)
+
     @slow
     def test_batch_generation_2heads(self):
         model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
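In plain terms, the new test checks that with left padding and mask-derived position ids, the prediction for the last real token does not depend on how much padding precedes it. That property is what makes padded batch generation usable; a hedged end-to-end sketch along the same lines (model choice and prompts are arbitrary, and this snippet is not part of the test file):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token

model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

# Prompts of different lengths get left-padded to a common length.
prompts = ["Hello, my dog is a little", "Today I"]
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

# With the attention mask passed through, the shorter prompt should continue
# the same way it would without any padding in the batch.
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))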