mirror of https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00

Commit: d880a5fbde ("finalized PR")
Parent: 2acfe63964
@@ -798,7 +798,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
 
         # create attention mask if necessary
-        # TODO (PVP): this should later be handled by the forward fn() in each model
+        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
         if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
             attention_mask = input_ids.ne(pad_token_id).long()
         elif attention_mask is None:
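Reviewer note: the touched branch builds an attention mask from pad positions when the caller does not supply one. A minimal standalone sketch of that behaviour in plain PyTorch (tensor values are illustrative, not from the test suite):

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 7, 0, 0]])           # two real tokens, two pads
    attention_mask = input_ids.ne(pad_token_id).long()
    # tensor([[1, 1, 0, 0]]) -- attend to real tokens, mask out padding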
@@ -989,10 +989,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if unfinished_sents.max() == 0:
                 break
 
-            # extend attention_mask for new generated input
+            # extend attention_mask for new generated input if only decoder
             if self.config.is_encoder_decoder is False:
                 attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )
 
             cur_len = cur_len + 1
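Reviewer note: this hunk fixes the shape of the mask extension during generation. Each decoding step should append one new position (a column of ones) for every sequence in the batch; the old `new_ones((1, attention_mask.shape[-1]))` only concatenates cleanly when batch_size == 1, and even then doubles the mask length instead of growing it by one. A small sketch of the corrected behaviour (shapes are illustrative):

    import torch

    attention_mask = torch.ones(3, 5)  # (batch_size=3, cur_len=5)
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
    print(attention_mask.shape)  # torch.Size([3, 6]) -- one extra column per step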
@@ -1078,7 +1078,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature
 
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            if self.config.is_encoder_decoder:  # TODO(PVP) to be refactored later - do we need this boolean flag here?
+            if (
+                self.config.is_encoder_decoder
+            ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search?
                 scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
 
             # set eos token prob to zero if min_length is not reached
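Reviewer note: beam search works on log-probabilities so that beam scores can be accumulated by addition. A sketch of the scoring step around this hunk (dimensions are illustrative; `prepare_scores_for_generation` is a model-specific hook that adjusts scores for encoder-decoder models):

    import torch
    import torch.nn.functional as F

    batch_size, num_beams, vocab_size = 2, 4, 50265
    next_token_logits = torch.randn(batch_size * num_beams, vocab_size)
    scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
    # encoder-decoder models may then rescale or force specific tokens at cur_len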
@@ -1205,10 +1207,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if past:
                 past = self._reorder_cache(past, beam_idx)
 
-            # extend attention_mask for new generated input
+            # extend attention_mask for new generated input if only decoder
             if self.config.is_encoder_decoder is False:
                 attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )
 
             # update current length
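Reviewer note: `_reorder_cache` shuffles the cached past key/values so every surviving beam continues from the hypothesis it was selected from, and the same mask-shape fix as above is applied on the beam-search path. A schematic sketch of the reordering idea (the real cache layout is model-specific):

    import torch

    # toy per-layer state with the beam dimension first
    past = (torch.tensor([[0.0], [1.0], [2.0], [3.0]]),)
    beam_idx = torch.tensor([2, 2, 0, 3])  # which old beam each new beam continues
    past = tuple(layer.index_select(0, beam_idx) for layer in past)
    # past[0].squeeze() -> tensor([2., 2., 0., 3.])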
@@ -1270,7 +1272,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
 
         if self.config.is_encoder_decoder:
-            # do not return first <BOS> token
+            # do not return first <EOS> token
             return decoded[:, 1:]
         return decoded
 
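Reviewer note: encoder-decoder generation seeds the decoder with a start token (for Bart that token is `<EOS>`, hence the comment fix), so position 0 is sliced off before the sequences are returned. Sketch (token ids are illustrative):

    import torch

    decoded = torch.tensor([[2, 31414, 232, 2]])  # column 0 is the decoder start token
    print(decoded[:, 1:])  # tensor([[31414,   232,     2]])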
@@ -453,9 +453,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
 
         dct = tok.batch_encode_plus(
-            # [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
-            [IRAN_ARTICLE, ARTICLE_SUBWAY],
-            # [FRANCE_ARTICLE, SHORTER_ARTICLE],
+            [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
             max_length=1024,
             pad_to_max_length=True,
             return_tensors="pt",
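Reviewer note: the test goes back to encoding all four articles in one batch instead of the commented-out subsets. Under the transformers API of this era, the call above produces padded tensors roughly as in this sketch (assuming `tok` is the Bart tokenizer used by the test; `pad_to_max_length` predates the later `padding=` keyword):

    dct = tok.batch_encode_plus(
        [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
        max_length=1024,
        pad_to_max_length=True,
        return_tensors="pt",
    )
    # dct["input_ids"]: shape (4, 1024), padded/truncated to max_length
    # dct["attention_mask"]: shape (4, 1024), 1 for real tokens, 0 for pads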
@@ -482,9 +480,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         ]
 
         self.assertListEqual(
-            # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
-            [EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
-            # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
+            [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
             decoded,
         )
         # TODO(SS): run fairseq again with num_beams=2, min_len=20.
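Reviewer note: pinning the full expected summaries keeps this an exact-match regression test; `assertListEqual` reports a readable element-wise diff on failure. A minimal self-contained example of the pattern:

    import unittest

    class ExpectedOutputTest(unittest.TestCase):
        def test_summaries_match(self):
            expected = ["summary a", "summary b"]
            produced = ["summary a", "summary b"]  # stand-in for decoded model output
            self.assertListEqual(expected, produced)

    if __name__ == "__main__":
        unittest.main()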