Mirror of https://github.com/huggingface/transformers.git
finalize generation merge
commit a332cc9f7f (parent 1ba21f96ca)
@@ -40,8 +40,9 @@ class BartConfig(PretrainedConfig):
         self,
         activation_dropout=0.0,
         vocab_size=50265,
+        bos_token_id=0,
         pad_token_id=1,
-        eos_token_id=2,
+        eos_token_ids=[2],
         d_model=1024,
         encoder_ffn_dim=4096,
         encoder_layers=12,
@@ -58,7 +59,6 @@ class BartConfig(PretrainedConfig):
         classifier_dropout=0.0,
         output_past=False,
         num_labels=3,
-        bos_token_id=0,
         is_encoder_decoder=True,
         **common_kwargs
     ):
@@ -73,12 +73,12 @@ class BartConfig(PretrainedConfig):
             output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            eos_token_ids=eos_token_ids,
             is_encoder_decoder=is_encoder_decoder,
             **common_kwargs,
         )
         self.vocab_size = vocab_size
         self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
-        self.eos_token_id = eos_token_id
         self.encoder_ffn_dim = encoder_ffn_dim
         self.encoder_layers = self.num_hidden_layers = encoder_layers
         self.encoder_attention_heads = encoder_attention_heads
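
The three BartConfig hunks above move bos_token_id up in the signature and replace the scalar eos_token_id with a one-element list, eos_token_ids, which is now forwarded to the base config instead of being stored on BartConfig directly. A minimal usage sketch under the new signature; the import path and the fact that the base config exposes eos_token_ids as an attribute are assumed from the surrounding code rather than shown in this diff:

from transformers import BartConfig

config = BartConfig(vocab_size=50265, bos_token_id=0, pad_token_id=1, eos_token_ids=[2])
# Callers that need a single id now index into the list, as the model code below does.
assert config.eos_token_ids[0] == 2
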
@@ -962,8 +962,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def prepare_scores_for_generation(self, scores, cur_len, max_length):
         if cur_len == 1:
             self._force_token_ids_generation(scores, self.config.bos_token_id)
-        if cur_len == max_length - 1:
-            self._force_token_ids_generation(scores, self.config.eos_token_ids)
+        if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
+            self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
         return scores

     @staticmethod
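
prepare_scores_for_generation now forces only eos_token_ids[0], and only when it is not None, instead of handing the whole list to the helper. The body of _force_token_ids_generation is not part of this diff, so the sketch below is a hedged stand-in for the idea with an invented helper name: keep the forced token's logit and push everything else to minus infinity so that greedy search, sampling and beam search all have to emit it.

import torch

def force_single_token(scores: torch.Tensor, token_id: int) -> torch.Tensor:
    # scores: (batch_size, vocab_size) next-token logits.
    forced = torch.full_like(scores, float("-inf"))
    forced[:, token_id] = scores[:, token_id]
    return forced

scores = torch.randn(2, 10)
assert force_single_token(scores, token_id=2).argmax(dim=-1).eq(2).all()
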
@@ -1056,7 +1056,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
         )
         x = outputs[0] # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_id)
+        eos_mask = input_ids.eq(self.config.eos_token_ids[0])
         if len(torch.unique(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
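
The sequence-classification head keeps its pooling strategy and only switches the lookup to eos_token_ids[0]: it pools the decoder hidden state at each sequence's final <eos> position. A self-contained illustration with dummy tensors; the shapes and the eos id of 2 are assumptions for the example, not values taken from the library:

import torch

eos_id = 2
input_ids = torch.tensor([[5, 6, 2, 2],
                          [7, 2, 8, 2]])  # every row holds the same number of <eos> tokens
x = torch.randn(2, 4, 16)                 # (batch, seq_len, hidden) last decoder hidden state
eos_mask = input_ids.eq(eos_id)
# Gather all <eos> positions, regroup them per example, keep the last one per row.
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
assert sentence_representation.shape == (2, 16)
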
@@ -840,14 +840,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)

         if self.config.is_encoder_decoder:
-            eos_token_id = eos_token_ids[0]
             assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
-            assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
             # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                # eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
                 bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
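
With the up-front eos_token_id bookkeeping gone, the encoder-decoder branch of generate() only asserts on bos_token_id: the user-provided input_ids become the encoder input, and decoding starts from one bos token per (batch x beam) row. A small sketch of that seeding step with made-up sizes; the variable names mirror the diff, and device placement is left out for brevity:

import torch

effective_batch_size, num_beams, bos_token_id = 2, 3, 0
encoder_inputs = torch.tensor([[5, 6, 7, 2],
                               [8, 9, 2, 1]])  # original input_ids, handed to the encoder
input_ids = torch.full((effective_batch_size * num_beams, 1), bos_token_id, dtype=torch.long)
assert input_ids.shape == (6, 1)               # one bos token per beam hypothesis
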
@@ -82,7 +82,7 @@ class ModelTester:
             dropout=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
-            eos_token_ids=self.eos_token_id,
+            eos_token_ids=[self.eos_token_id],
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=output_past,
-            eos_token_id=2,
+            eos_token_ids=[2],
             pad_token_id=1,
             bos_token_id=0,
         )
@@ -276,7 +276,7 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=True,
-            eos_token_ids=2,
+            eos_token_ids=[2],
             pad_token_id=1,
             bos_token_id=0,
         )
@@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase):
         new_input_ids = lm_model.generate(
             input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
         )
-        self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
+        self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
         # TODO(SS): uneven length batches, empty inputs

     def test_shift_tokens_right(self):