mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
update forward pass
This commit is contained in:
parent
0ef9bc923a
commit
bfbe68f035
@ -218,12 +218,14 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(*inputs, *kwargs)
|
encoder_outputs = self.encoder(*inputs, *kwargs)
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
else:
|
||||||
|
encoder_outputs = (,)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
|
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
|
||||||
decoder_outputs = self.decoder(**decoder_kwargs)
|
decoder_outputs = self.decoder(**decoder_kwargs)
|
||||||
|
|
||||||
return decoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
|
||||||
class Model2Model(PreTrainedSeq2seq):
|
class Model2Model(PreTrainedSeq2seq):
|
||||||
|
Loading…
Reference in New Issue
Block a user