mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[Generate] Small refactor (#15611)
This commit is contained in:
parent
c0864d98ba
commit
45c7b5b1c7
@ -525,12 +525,6 @@ class GenerationMixin:
|
||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
|
||||
|
||||
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
return pad_token_id
|
||||
|
||||
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
@ -1063,9 +1057,15 @@ class GenerationMixin:
|
||||
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
if eos_token_id is None and hasattr(self.config, "decoder"):
|
||||
eos_token_id = self.config.decoder.eos_token_id
|
||||
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
# special case if pad_token_id is not defined
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1075,11 +1075,6 @@ class GenerationMixin:
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
# special case if pad_token_id is not defined
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# 2. Define model inputs
|
||||
# inputs_tensor has to be defined
|
||||
# model_input_name is defined if model-specific keyword input is passed
|
||||
|
Loading…
Reference in New Issue
Block a user