mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update forward pass
This commit is contained in:
parent
0ef9bc923a
commit
bfbe68f035
@ -218,12 +218,14 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(*inputs, *kwargs)
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
else:
|
||||
encoder_outputs = (,)
|
||||
|
||||
# Decode
|
||||
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
|
||||
decoder_outputs = self.decoder(**decoder_kwargs)
|
||||
|
||||
return decoder_outputs
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
|
Loading…
Reference in New Issue
Block a user