changed do_output_past function to check for self.config.output_past instead of self.output_past

This commit is contained in:
patrickvonplaten 2019-12-23 22:33:45 +01:00
parent eeaa402cd4
commit 7e0c5c731a

View File

@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
def _has_past(self, outputs):
# TODO: might be better to write a self.has_past method for each individual class as is done for
def _do_output_past(self, outputs):
# TODO: might be better to write a self.do_output_past method for each individual class as is done for
# prepare_inputs_for_generation
if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'):
return True
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
return False
@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding
if self._has_past(outputs):
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._has_past(outputs):
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)