[Generate] Small refactor (#15611)

This commit is contained in:
Patrick von Platen 2022-02-10 18:29:27 +01:00 committed by GitHub
parent c0864d98ba
commit 45c7b5b1c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -525,12 +525,6 @@ class GenerationMixin:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
return pad_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
@ -1063,9 +1057,15 @@ class GenerationMixin:
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id
if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1075,11 +1075,6 @@ class GenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed