Mirror of https://github.com/huggingface/transformers.git
Changed the do_output_past function to check for self.config.output_past instead of self.output_past
This commit is contained in:
parent eeaa402cd4
commit 7e0c5c731a
@@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-    def _has_past(self, outputs):
-        # TODO: might be better to write a self.has_past method for each individual class as is done for
+    def _do_output_past(self, outputs):
+        # TODO: might be better to write a self.do_output_past method for each individual class as is done for
         # prepare_inputs_for_generation
-        if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
+        if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'):
             return True
         # TODO: Add cases for (xlnet, transfo_xl) using mem_len
         return False
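For readers skimming the hunk above: after this change the past-output check keys off the model config rather than the model instance, and it explicitly skips models that expose mem_len. A minimal standalone sketch of that logic, using hypothetical DummyConfig/DummyModel stand-ins rather than real transformers classes:

class DummyConfig:
    output_past = True  # e.g. GPT-2-style configs expose this flag

class DummyModel:
    config = DummyConfig()  # the flag now lives on the config, not the model

    def _do_output_past(self, outputs):
        # Mirrors the new condition: the config flag is set, a second output
        # (the past) is present, and the model has no mem_len attribute
        # (XLNet / Transformer-XL manage memory differently, per the TODO).
        return (
            hasattr(self.config, "output_past")
            and self.config.output_past
            and len(outputs) > 1
            and not hasattr(self, "mem_len")
        )

model = DummyModel()
print(model._do_output_past(("logits", "past")))  # True
print(model._do_output_past(("logits",)))         # False: no past returned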
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
             next_token_logits = outputs[0][:, -1, :]
 
             # if model has past, then set the past variable to speed up decoding
-            if self._has_past(outputs):
+            if self._do_output_past(outputs):
                 past = outputs[1]
 
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
@@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
             scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
 
             # if model has past, then set the past variable to speed up decoding
-            if self._has_past(outputs):
+            if self._do_output_past(outputs):
                 past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
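The two call sites above cache the model's second output so later steps can reuse it. A hedged sketch of that decoding pattern with a toy model (ToyModel is invented for illustration; the real generate() loop does considerably more work):

import torch

# Illustrative only: a toy stand-in that returns (logits, past) the way a
# transformers model does when config.output_past is True. Not library code.
class ToyModel:
    class config:
        output_past = True

    vocab_size = 10

    def __call__(self, input_ids, past=None):
        # Random logits; a real model would attend over `past` here.
        logits = torch.randn(input_ids.shape[0], input_ids.shape[1], self.vocab_size)
        return (logits, "opaque-cache")  # stands in for cached key/value states

    def _do_output_past(self, outputs):
        return (hasattr(self.config, "output_past") and self.config.output_past
                and len(outputs) > 1 and not hasattr(self, "mem_len"))

model = ToyModel()
input_ids = torch.zeros((1, 3), dtype=torch.long)
past = None
for _ in range(5):
    # With a cached past, only the newest token needs to be fed forward,
    # which is the speedup the comment in the diff refers to.
    step_input = input_ids if past is None else input_ids[:, -1:]
    outputs = model(step_input, past=past)
    next_token_logits = outputs[0][:, -1, :]
    if model._do_output_past(outputs):
        past = outputs[1]
    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)
print(input_ids.shape)  # torch.Size([1, 8])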