Remove max length beam scorer (#11378)

* removed max_len

* removed max_length from BeamSearchScorer

* correct max length

* finish

* del vim

* finish & add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Ashwin Geet D'Sa, 2021-04-27 00:28:40 +02:00 (committed by GitHub)
parent bc2571e61c
commit 741d48f5c7
6 changed files with 91 additions and 38 deletions
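
The change in a nutshell: BeamSearchScorer no longer owns `max_length`. Length control moves to the stopping criteria passed to `beam_search(...)`, `beam_sample(...)`, and `group_beam_search(...)`, and `BeamSearchScorer.finalize(...)` receives `max_length` explicitly. A minimal caller-side sketch of the new usage, modeled on the test added below (same tiny BART checkpoint; treat it as illustration, not part of the commit):

import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BeamSearchScorer,
    MaxLengthCriteria,
    StoppingCriteriaList,
)

tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")

num_beams = 3
input_ids = tokenizer("Welcome to parenthood.", return_tensors="pt").input_ids
input_ids = input_ids.expand(num_beams, -1)  # batch_size 1, replicated per beam
model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})

# Before this commit: BeamSearchScorer(batch_size=1, max_length=18, num_beams=num_beams, ...)
# After: the scorer takes no max_length; a MaxLengthCriteria drives termination instead.
beam_scorer = BeamSearchScorer(batch_size=1, num_beams=num_beams, device=model.device)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=18)])

generated = model.beam_search(
    input_ids,
    beam_scorer=beam_scorer,
    stopping_criteria=stopping_criteria,
    **model_kwargs,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))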

src/transformers/generation_beam_search.py View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import Optional, Tuple
@@ -110,6 +111,7 @@ class BeamScorer(ABC):
         next_scores: torch.FloatTensor,
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
+        max_length: int,
         **kwargs
     ) -> torch.LongTensor:
         raise NotImplementedError("This is an abstract method.")
@@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer):
     def __init__(
         self,
         batch_size: int,
-        max_length: int,
         num_beams: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
         do_early_stopping: Optional[bool] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
+        **kwargs,
     ):
-        self.max_length = max_length
         self.num_beams = num_beams
         self.device = device
         self.length_penalty = length_penalty
@@ -173,7 +174,6 @@ class BeamSearchScorer(BeamScorer):
         self._beam_hyps = [
             BeamHypotheses(
                 num_beams=self.num_beams,
-                max_length=self.max_length,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
             )
@@ -192,6 +192,13 @@ class BeamSearchScorer(BeamScorer):
                 f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )
 
+        if "max_length" in kwargs:
+            warnings.warn(
+                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
+                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`, "
+                "or `group_beam_search(...)`."
+            )
+
     @property
     def is_done(self) -> bool:
         return self._done.all()
@@ -279,6 +286,7 @@ class BeamSearchScorer(BeamScorer):
         final_beam_scores: torch.FloatTensor,
         final_beam_tokens: torch.LongTensor,
         final_beam_indices: torch.LongTensor,
+        max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
     ) -> Tuple[torch.LongTensor]:
@@ -316,7 +324,7 @@ class BeamSearchScorer(BeamScorer):
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
 
         # prepare for adding eos
-        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
+        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
@@ -326,7 +334,7 @@ class BeamSearchScorer(BeamScorer):
         # fill with hypotheses and eos_token_id if the latter fits in
         for i, hypo in enumerate(best):
             decoded[i, : sent_lengths[i]] = hypo
-            if sent_lengths[i] < self.max_length:
+            if sent_lengths[i] < max_length:
                 decoded[i, sent_lengths[i]] = eos_token_id
 
         return UserDict(
             {
@@ -337,11 +345,10 @@
 
 
 class BeamHypotheses:
-    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
         """
         Initialize n-best list of hypotheses.
         """
-        self.max_length = max_length - 1  # ignoring bos_token
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
         self.num_beams = num_beams
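
For code that drives the scorer manually, the visible consequence is that `finalize(...)` now takes `max_length` at the call site instead of reading scorer state. A self-contained sketch with dummy beam tensors (shapes, scores, and token ids below are assumptions for illustration, not from the commit):

import torch
from transformers import BeamSearchScorer

batch_size, num_beams, cur_len = 1, 2, 4
scorer = BeamSearchScorer(batch_size=batch_size, num_beams=num_beams, device=torch.device("cpu"))

# Dummy final beam state, shaped (batch_size * num_beams,) as in the generation loop.
input_ids = torch.randint(5, 99, (batch_size * num_beams, cur_len))
final_beam_scores = torch.tensor([-0.5, -1.2])
final_beam_tokens = torch.tensor([7, 8])
final_beam_indices = torch.tensor([0, 1])

outputs = scorer.finalize(
    input_ids,
    final_beam_scores,
    final_beam_tokens,
    final_beam_indices,
    max_length=cur_len + 1,  # now an explicit argument, no longer scorer state
    pad_token_id=0,
    eos_token_id=2,
)
print(outputs["sequences"].shape)  # (batch_size * num_beam_hyps_to_keep, cur_len + 1)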

src/transformers/generation_utils.py View File

@@ -1027,7 +1027,6 @@ class GenerationMixin:
 
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1063,7 +1062,6 @@ class GenerationMixin:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1700,7 +1698,6 @@ class GenerationMixin:
         >>> # instantiate beam scorer
         >>> beam_scorer = BeamSearchScorer(
         ...     batch_size=1,
-        ...     max_length=model.config.max_length,
         ...     num_beams=num_beams,
         ...     device=model.device,
         ... )
@@ -1756,7 +1753,7 @@ class GenerationMixin:
         assert (
             num_beams * batch_size == batch_beam_size
-        ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
 
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
@@ -1792,10 +1789,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
             next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -1861,7 +1855,13 @@ class GenerationMixin:
                     this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
@@ -2086,10 +2086,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
             next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -2160,7 +2157,13 @@ class GenerationMixin:
                     this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
@@ -2411,10 +2414,7 @@ class GenerationMixin:
                 # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                 # cannot be generated both before and after the `F.log_softmax` operation.
-                next_token_logits = self.adjust_logits_during_generation(
-                    next_token_logits, cur_len=cur_len, max_length=None
-                )
+                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                 next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
 
                 vocab_size = next_token_scores.shape[-1]
@@ -2497,7 +2497,13 @@ class GenerationMixin:
                     this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
        )
 
         if return_dict_in_generate:
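
Nothing changes for the high-level API: `generate()` still accepts `max_length` and enforces it through stopping criteria (see the `ValueError` guard above) rather than through scorer state. A short sketch, again using the tiny test checkpoint:

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")

inputs = tokenizer("Hello world", return_tensors="pt")
summary_ids = model.generate(inputs.input_ids, num_beams=3, max_length=20)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))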

src/transformers/models/marian/modeling_marian.py View File

@@ -1335,7 +1335,7 @@ class MarianMTModel(MarianPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def adjust_logits_during_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len):
         logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
         return logits
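
Models that override `adjust_logits_during_generation` must drop the `max_length` parameter to match the new call sites. A hypothetical stand-in class, only to illustrate the new hook signature:

import torch

class TinyMarianLike:
    """Hypothetical stand-in; real models subclass PreTrainedModel."""

    pad_token_id = 0

    def adjust_logits_during_generation(self, logits, cur_len):
        logits[:, self.pad_token_id] = float("-inf")  # never predict pad
        return logits

logits = TinyMarianLike().adjust_logits_during_generation(torch.zeros(2, 5), cur_len=1)
print(logits[:, 0])  # tensor([-inf, -inf])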

src/transformers/models/rag/modeling_rag.py View File

@@ -1543,7 +1543,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
             raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=self.device,
             length_penalty=length_penalty,

tests/test_generation_beam_search.py View File

@@ -59,7 +59,6 @@ class BeamSearchTester:
     def prepare_beam_scorer(self, **kwargs):
         return BeamSearchScorer(
             batch_size=kwargs.get("batch_size", self.batch_size),
-            max_length=kwargs.get("max_length", self.max_length),
             num_beams=kwargs.get("num_beams", self.num_beams),
             device=torch_device,
             length_penalty=kwargs.get("length_penalty", self.length_penalty),
@@ -170,9 +169,7 @@ class BeamSearchTester:
     def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
         # max_length should be only one more than current input_ids to check that eos is correctly appended
         max_length = self.sequence_length + 1
-        beam_scorer = self.prepare_beam_scorer(
-            num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
-        )
+        beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
 
         # update beams and append to input_ids
         tokens = next_tokens.clone()
@@ -197,6 +194,7 @@ class BeamSearchTester:
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
 
         sequences = sequence_output["sequences"]
@@ -225,6 +223,7 @@ class BeamSearchTester:
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
         sequences = sequence_output["sequences"]
         sequence_scores = sequence_output["sequence_scores"]

tests/test_generation_utils.py View File

@@ -148,7 +148,6 @@ class GenerationTesterMixin:
         }
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=beam_kwargs["num_beams"],
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
@@ -169,7 +168,6 @@ class GenerationTesterMixin:
         }
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=beam_kwargs["num_beams"],
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
@@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
 
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
 
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # Beam
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # Grouped beam search
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
             max_length=max_length,
             **model_kwargs,
         )
+
+    def test_beam_search_warning_if_max_length_is_passed(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+
+        batch_size = 1
+        num_beams = 3
+
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        input_ids = input_ids.expand(num_beams, -1)
+        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+
+        stopping_criteria_max_length = 18
+        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
+
+        with self.assertWarns(UserWarning):
+            beam_scorer = BeamSearchScorer(
+                batch_size=batch_size,
+                num_beams=num_beams,
+                device=torch_device,
+                max_length=10,
+            )
+
+        generated_ids = bart_model.beam_search(
+            input_ids,
+            num_beams=num_beams,
+            stopping_criteria=stopping_criteria,
+            beam_scorer=beam_scorer,
+            **model_kwargs,
+        )
+
+        beam_scorer_no_max_len = BeamSearchScorer(
+            batch_size=batch_size,
+            num_beams=num_beams,
+            device=torch_device,
+        )
+
+        generated_ids_no_max_len = bart_model.beam_search(
+            input_ids,
+            num_beams=num_beams,
+            stopping_criteria=stopping_criteria,
+            beam_scorer=beam_scorer_no_max_len,
+            **model_kwargs,
+        )
+
+        # BeamSearchScorer max_length should not influence "real" max_length
+        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())