separate inputs into encoder & decoder inputs
This commit is contained in:
parent
e4e0ee14bd
commit
95ec1d08be
@@ -130,7 +130,7 @@ class PreTrainedSeq2seq(nn.Module):
         return model
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2seq depends on what we are performing:
 
         - During training we perform one forward pass through both the encoder
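For context on the signature change above, a call site might look roughly like the following; the tensors, vocabulary size, and the `model` name are assumptions for illustration and are not part of this commit.

import torch

# Hypothetical call-site sketch for the new signature: encoder and decoder
# now receive their own input ids instead of a shared positional *inputs.
encoder_input_ids = torch.randint(0, 30000, (2, 16))  # (batch_size, sequence_length)
decoder_input_ids = torch.randint(0, 30000, (2, 8))   # (batch_size, sequence_length)

# Before (old signature): positional inputs were passed straight to the encoder,
#   outputs = model(input_ids, **kwargs)
# After (new signature): encoder and decoder inputs are explicit and separate,
#   outputs = model(encoder_input_ids, decoder_input_ids, **kwargs)
print(encoder_input_ids.shape, decoder_input_ids.shape)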
@@ -142,6 +142,11 @@ class PreTrainedSeq2seq(nn.Module):
         Therefore, we skip the forward pass on the encoder if an argument named
         `encoder_hidden_states` is passed to this function.
 
+        Params:
+            encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
+                Indices of encoder input sequence tokens in the vocabulary.
+            decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
+                Indices of decoder input sequence tokens in the vocabulary.
         """
         # Separate the encoder- and decoder-specific kwargs. A kwarg is
         # decoder-specific if the key starts with `decoder_`
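The prefix-based kwargs split described in the comment above could be sketched roughly as below; the helper name, the prefix stripping, and the handling of un-prefixed kwargs are assumptions for illustration rather than the exact logic elided from this hunk.

# Hypothetical sketch of the prefix-based kwargs split described above:
# `decoder_`-prefixed kwargs go to the decoder, `encoder_`-prefixed ones to the
# encoder, and (as an assumption here) anything else is shared by both.
def split_seq2seq_kwargs(kwargs):
    kwargs_encoder, kwargs_decoder = {}, {}
    for key, value in kwargs.items():
        if key.startswith("decoder_"):
            kwargs_decoder[key[len("decoder_"):]] = value
        elif key.startswith("encoder_"):
            kwargs_encoder[key[len("encoder_"):]] = value
        else:
            kwargs_encoder[key] = value
            kwargs_decoder[key] = value
    return kwargs_encoder, kwargs_decoder

if __name__ == "__main__":
    enc_kw, dec_kw = split_seq2seq_kwargs(
        {"encoder_token_type_ids": None, "decoder_attention_mask": None, "head_mask": None}
    )
    print(enc_kw)  # {'token_type_ids': None, 'head_mask': None}
    print(dec_kw)  # {'attention_mask': None, 'head_mask': None}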
@@ -154,14 +159,14 @@ class PreTrainedSeq2seq(nn.Module):
 
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
         if encoder_hidden_states is None:
-            encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
+            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()
 
         # Decode
         kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
-        decoder_outputs = self.decoder(**kwargs_decoder)
+        decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
 
         return decoder_outputs + encoder_outputs
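To make the encode-once, decode-many control flow above concrete, here is a small self-contained sketch with toy encoder and decoder modules; the class, layer choices, and dimensions are invented for illustration and only mirror the skip-the-encoder logic shown in the hunk.

import torch
import torch.nn as nn

class ToySeq2seq(nn.Module):
    """Toy stand-in mirroring the control flow above: the encoder is only run
    when no cached `encoder_hidden_states` kwarg is supplied."""

    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, hidden_size)  # stand-in encoder
        self.decoder = nn.Linear(hidden_size, vocab_size)     # stand-in decoder

    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
        encoder_hidden_states = kwargs.pop("encoder_hidden_states", None)
        if encoder_hidden_states is None:
            # Training / first prediction pass: run the encoder.
            encoder_hidden_states = self.encoder(encoder_input_ids)
        # Decode, conditioning on the (possibly cached) encoder states.
        context = encoder_hidden_states.mean(dim=1, keepdim=True)
        logits = self.decoder(self.encoder(decoder_input_ids) + context)
        return logits, encoder_hidden_states

model = ToySeq2seq()
src = torch.randint(0, 100, (2, 7))
tgt = torch.randint(0, 100, (2, 3))

# First pass encodes; later passes reuse the cached hidden states and skip the encoder.
logits, cached = model(src, tgt)
logits_again, _ = model(src, tgt, encoder_hidden_states=cached)
print(logits.shape, logits_again.shape)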