mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
setup training
This commit is contained in:
parent 4735c2af07
commit 9f75565ea8
@@ -10,5 +10,3 @@ regex
 sentencepiece
 # For XLM
 sacremoses
-# For ROUGE
-pyrouge
@@ -166,7 +166,7 @@ class BeamSearch(object):
         for step in range(self.max_length):

             decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
-            kwargs_decoder["attention_mask"] = build_mask(decoder_input)
+            kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
             outputs = self.model.decoder(decoder_input, **kwargs_decoder)

             next_token_scores = outputs[0][:, -1, :].squeeze(1)
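The only functional change in this hunk is that build_mask now receives self.pad_token_id, presumably so the attention mask can distinguish padding added by fit_to_block_size from real token ids rather than assuming the pad id is 0. A minimal sketch of what such helpers might look like, inferred from the call sites above rather than from the diff itself (names, shapes, and behavior are assumptions):

import torch


def build_mask(sequence, pad_token_id):
    # Hypothetical helper: mask is 1 for real tokens, 0 wherever the id equals
    # the pad token, so the decoder attends only to non-padding positions.
    mask = torch.ones_like(sequence)
    mask[sequence == pad_token_id] = 0
    return mask


def fit_to_block_size(sequence, block_size, pad_token_id):
    # Hypothetical helper: right-pad (or truncate) a batch of beams to
    # block_size using pad_token_id, matching the decoder's expected length.
    if sequence.size(1) >= block_size:
        return sequence[:, :block_size]
    padding = sequence.new_full((sequence.size(0), block_size - sequence.size(1)), pad_token_id)
    return torch.cat([sequence, padding], dim=1)

Passing the tokenizer's actual pad id matters whenever that id is nonzero; a mask built against a hard-coded 0 would otherwise let the decoder attend to padded positions.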