mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove max length beam scorer (#11378)
* removed max_len * removed max_length from BeamSearchScorer * correct max length * finish * del vim * finish & add test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
bc2571e61c
commit
741d48f5c7
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from typing import Optional, Tuple
|
||||
@ -110,6 +111,7 @@ class BeamScorer(ABC):
|
||||
next_scores: torch.FloatTensor,
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
**kwargs
|
||||
) -> torch.LongTensor:
|
||||
raise NotImplementedError("This is an abstract method.")
|
||||
@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
num_beams: int,
|
||||
device: torch.device,
|
||||
length_penalty: Optional[float] = 1.0,
|
||||
do_early_stopping: Optional[bool] = False,
|
||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||
num_beam_groups: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.device = device
|
||||
self.length_penalty = length_penalty
|
||||
@ -173,7 +174,6 @@ class BeamSearchScorer(BeamScorer):
|
||||
self._beam_hyps = [
|
||||
BeamHypotheses(
|
||||
num_beams=self.num_beams,
|
||||
max_length=self.max_length,
|
||||
length_penalty=self.length_penalty,
|
||||
early_stopping=self.do_early_stopping,
|
||||
)
|
||||
@ -192,6 +192,13 @@ class BeamSearchScorer(BeamScorer):
|
||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
if "max_length" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect."
|
||||
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
|
||||
",or `group_beam_search(...)`."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self._done.all()
|
||||
@ -279,6 +286,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
final_beam_scores: torch.FloatTensor,
|
||||
final_beam_tokens: torch.LongTensor,
|
||||
final_beam_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> Tuple[torch.LongTensor]:
|
||||
@ -316,7 +324,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
@ -326,7 +334,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < self.max_length:
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
return UserDict(
|
||||
{
|
||||
@ -337,11 +345,10 @@ class BeamSearchScorer(BeamScorer):
|
||||
|
||||
|
||||
class BeamHypotheses:
|
||||
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
|
||||
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
self.max_length = max_length - 1 # ignoring bos_token
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.num_beams = num_beams
|
||||
|
@ -1027,7 +1027,6 @@ class GenerationMixin:
|
||||
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=stopping_criteria.max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
@ -1063,7 +1062,6 @@ class GenerationMixin:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=stopping_criteria.max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
@ -1700,7 +1698,6 @@ class GenerationMixin:
|
||||
>>> # instantiate beam scorer
|
||||
>>> beam_scorer = BeamSearchScorer(
|
||||
... batch_size=1,
|
||||
... max_length=model.config.max_length,
|
||||
... num_beams=num_beams,
|
||||
... device=model.device,
|
||||
... )
|
||||
@ -1756,7 +1753,7 @@ class GenerationMixin:
|
||||
|
||||
assert (
|
||||
num_beams * batch_size == batch_beam_size
|
||||
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
@ -1792,10 +1789,7 @@ class GenerationMixin:
|
||||
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=None
|
||||
)
|
||||
|
||||
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
||||
@ -1861,7 +1855,13 @@ class GenerationMixin:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
@ -2086,10 +2086,7 @@ class GenerationMixin:
|
||||
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=None
|
||||
)
|
||||
|
||||
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
||||
@ -2160,7 +2157,13 @@ class GenerationMixin:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
@ -2411,10 +2414,7 @@ class GenerationMixin:
|
||||
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=None
|
||||
)
|
||||
|
||||
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
|
||||
@ -2497,7 +2497,13 @@ class GenerationMixin:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
|
@ -1335,7 +1335,7 @@ class MarianMTModel(MarianPreTrainedModel):
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
def adjust_logits_during_generation(self, logits, cur_len):
|
||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||
return logits
|
||||
|
||||
|
@ -1543,7 +1543,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
|
@ -59,7 +59,6 @@ class BeamSearchTester:
|
||||
def prepare_beam_scorer(self, **kwargs):
|
||||
return BeamSearchScorer(
|
||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||
max_length=kwargs.get("max_length", self.max_length),
|
||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||
device=torch_device,
|
||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||
@ -170,9 +169,7 @@ class BeamSearchTester:
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
beam_scorer = self.prepare_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
|
||||
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
@ -197,6 +194,7 @@ class BeamSearchTester:
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
@ -225,6 +223,7 @@ class BeamSearchTester:
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
@ -148,7 +148,6 @@ class GenerationTesterMixin:
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
@ -169,7 +168,6 @@ class GenerationTesterMixin:
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# Beam
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# Grouped beam search
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
max_length=max_length,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_beam_search_warning_if_max_length_is_passed(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
num_beams = 3
|
||||
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids = input_ids.expand(num_beams, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
|
||||
stopping_criteria_max_length = 18
|
||||
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
generated_ids = bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
beam_scorer_no_max_len = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
generated_ids_no_max_len = bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer_no_max_len,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# BeamSearchScorer max_length should not influence "real" max_length
|
||||
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
||||
|
Loading…
Reference in New Issue
Block a user