[Bart] dont call .forward (#3094)

Sam Shleifer 2020-03-03 15:14:12 -05:00 committed by GitHub
parent a088d75e51
commit 5c5af879b6

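The motivation for dropping the explicit `.forward(...)` calls is that `nn.Module.__call__` wraps `forward` and runs any registered forward pre-hooks and forward hooks; calling `.forward()` directly skips them. A minimal sketch of the difference (the `Toy` module and the hook below are illustrative, not part of the Bart code):

# Minimal sketch: module(x) vs module.forward(x); Toy is a hypothetical stand-in.
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


calls = []
toy = Toy()
# Feature extraction and debugging utilities rely on hooks like this one firing.
toy.register_forward_hook(lambda module, inputs, output: calls.append("hook ran"))

x = torch.randn(2, 4)
toy(x)          # routed through nn.Module.__call__, so the hook fires
toy.forward(x)  # bypasses __call__, so the hook is skipped
print(calls)    # ['hook ran'] -- only the direct module call triggered it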

@@ -208,7 +208,7 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
-x, attn_weights = self.self_attn.forward(
+x, attn_weights = self.self_attn(
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -292,7 +292,7 @@ class BartEncoder(nn.Module):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
-x, attn = encoder_layer.forward(x, attention_mask)
+x, attn = encoder_layer(x, attention_mask)
if self.output_attentions:
all_attentions.append(attn)
@@ -356,7 +356,7 @@ class DecoderLayer(nn.Module):
if layer_state is None:
layer_state = {}
# next line mutates layer state
-x, self_attn_weights = self.self_attn.forward(
+x, self_attn_weights = self.self_attn(
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -365,7 +365,7 @@ class DecoderLayer(nn.Module):
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
-x, encoder_attn_weights = self.encoder_attn.forward(
+x, encoder_attn_weights = self.encoder_attn(
query=x,
key=encoder_hidden_states, # could be None
value=encoder_hidden_states,
@@ -449,7 +449,7 @@ class BartDecoder(nn.Module):
- attentions
"""
# embed positions
-positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
+positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
if self.generation_mode:
input_ids = input_ids[:, -1:]
@@ -475,7 +475,7 @@ class BartDecoder(nn.Module):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
-x, layer_self_attn, layer_past = decoder_layer.forward(
+x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
encoder_padding_mask,
@@ -836,10 +836,10 @@ class BartModel(PretrainedBartModel):
)
assert decoder_input_ids is not None
if encoder_outputs is None:
-encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
+encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
assert isinstance(encoder_outputs, tuple)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-decoder_outputs = self.decoder.forward(
+decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
@@ -925,7 +925,7 @@ class BartForMaskedLM(PretrainedBartModel):
outputs = model(input_ids=input_ids, lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
-outputs = self.model.forward(
+outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -933,7 +933,7 @@ class BartForMaskedLM(PretrainedBartModel):
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
)
-lm_logits = self.lm_head.forward(outputs[0])
+lm_logits = self.lm_head(outputs[0])
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
if lm_labels is not None:
loss_fct = nn.CrossEntropyLoss()
@@ -1308,7 +1308,7 @@ class BartForSequenceClassification(PretrainedBartModel):
loss, logits = outputs[:2]
"""
-outputs = self.model.forward(
+outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,