separate inputs into encoder & decoder inputs

This commit is contained in:
Rémi Louf 2019-10-16 20:55:42 +02:00
parent e4e0ee14bd
commit 95ec1d08be

View File

@ -130,7 +130,7 @@ class PreTrainedSeq2seq(nn.Module):
return model
def forward(self, *inputs, **kwargs):
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
""" The forward pass on a seq2eq depends what we are performing:
- During training we perform one forward pass through both the encoder
@ -142,6 +142,11 @@ class PreTrainedSeq2seq(nn.Module):
Therefore, we skip the forward pass on the encoder if an argument named
`encoder_hidden_state` is passed to this function.
Params:
encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of encoder input sequence tokens in the vocabulary.
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of decoder input sequence tokens in the vocabulary.
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific it the key starts with `decoder_`
@ -154,14 +159,14 @@ class PreTrainedSeq2seq(nn.Module):
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
# Decode
kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
decoder_outputs = self.decoder(**kwargs_decoder)
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs