update forward pass

This commit is contained in:
thomwolf 2019-10-14 12:04:23 +02:00
parent 0ef9bc923a
commit bfbe68f035

View File

@ -218,12 +218,14 @@ class PreTrainedSeq2seq(nn.Module):
if encoder_hidden_states is None:
encoder_outputs = self.encoder(*inputs, *kwargs)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = (,)
# Decode
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
decoder_outputs = self.decoder(**decoder_kwargs)
return decoder_outputs
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq):