Mirror of https://github.com/huggingface/transformers.git
[Bart] dont call .forward (#3094)
This commit is contained in:
parent a088d75e51
commit 5c5af879b6
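
The change is mechanical: every explicit submodule.forward(...) call in modeling_bart.py becomes a direct submodule(...) call. In PyTorch, calling a module goes through nn.Module.__call__, which dispatches to forward and also runs any registered pre-forward, forward, and backward hooks; calling .forward() directly silently bypasses those hooks. A minimal, non-BART sketch of the difference:

# Minimal sketch, not BART-specific: a forward hook fires when the module is
# called, but is silently skipped when .forward() is invoked directly.
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
layer.register_forward_hook(lambda module, inputs, output: print("hook ran"))

x = torch.randn(2, 4)
layer(x)          # prints "hook ran" -- goes through nn.Module.__call__
layer.forward(x)  # prints nothing -- bypasses __call__ and its hooks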
@@ -208,7 +208,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn.forward(
+        x, attn_weights = self.self_attn(
             query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
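
For context, here is a stripped-down version of the residual self-attention block this hunk sits in, with torch.nn.MultiheadAttention standing in for the library's own SelfAttention module (whose keyword arguments differ); the only point illustrated is that the attention submodule is invoked as self.self_attn(...):

# Stand-in sketch using torch.nn.MultiheadAttention (not the BART SelfAttention class).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyEncoderLayer(nn.Module):
    def __init__(self, embed_dim=16, num_heads=4, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.dropout = dropout
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, key_padding_mask=None):
        residual = x
        # call the submodule, not submodule.forward(), so hooks still run
        x, attn_weights = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(residual + x)
        return x, attn_weights

x = torch.randn(7, 2, 16)  # (seq_len, batch, embed_dim)
out, attn = TinyEncoderLayer()(x)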
@@ -292,7 +292,7 @@ class BartEncoder(nn.Module):
             if self.training and (dropout_probability < self.layerdrop): # skip the layer
                 attn = None
             else:
-                x, attn = encoder_layer.forward(x, attention_mask)
+                x, attn = encoder_layer(x, attention_mask)

             if self.output_attentions:
                 all_attentions.append(attn)
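
The surrounding loop is the LayerDrop trick: during training each encoder layer is skipped with probability self.layerdrop, otherwise it is called directly. A generic sketch of that pattern (not the BartEncoder class):

# Generic LayerDrop-style encoder loop (illustrative, not the BartEncoder class).
import random
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, make_layer, num_layers=6, layerdrop=0.1):
        super().__init__()
        self.layers = nn.ModuleList([make_layer() for _ in range(num_layers)])
        self.layerdrop = layerdrop

    def forward(self, x):
        for layer in self.layers:
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                continue
            x = layer(x)  # direct call, so any hooks on the layer still run
        return x

enc = TinyEncoder(lambda: nn.Linear(16, 16))
out = enc(torch.randn(2, 16))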
@@ -356,7 +356,7 @@ class DecoderLayer(nn.Module):
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
-        x, self_attn_weights = self.self_attn.forward(
+        x, self_attn_weights = self.self_attn(
             query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -365,7 +365,7 @@ class DecoderLayer(nn.Module):
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key

-        x, encoder_attn_weights = self.encoder_attn.forward(
+        x, encoder_attn_weights = self.encoder_attn(
             query=x,
             key=encoder_hidden_states, # could be None
             value=encoder_hidden_states,
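
These two DecoderLayer hunks touch the cached self-attention, which mutates layer_state so later decoding steps can reuse past keys and values, and the encoder-decoder cross-attention. A generic sketch (not the transformers attention class) of a per-layer cache dict being filled on the first step and grown on the next:

# Generic incremental-decoding cache sketch (not the transformers attention class):
# a per-layer dict stores past key/value projections and grows on each call.
import torch
import torch.nn as nn

class CachedKV(nn.Module):
    def __init__(self, embed_dim=16):
        super().__init__()
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, layer_state=None):
        k, v = self.k_proj(x), self.v_proj(x)
        if layer_state is not None:
            if "prev_key" in layer_state:
                k = torch.cat([layer_state["prev_key"], k], dim=0)
                v = torch.cat([layer_state["prev_value"], v], dim=0)
            # next lines mutate layer state, as the BART comment above notes
            layer_state["prev_key"] = k
            layer_state["prev_value"] = v
        return k, v

attn = CachedKV()
state = {}
k1, v1 = attn(torch.randn(1, 2, 16), layer_state=state)  # step 1: cached length 1
k2, v2 = attn(torch.randn(1, 2, 16), layer_state=state)  # step 2: cached length 2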
@@ -449,7 +449,7 @@ class BartDecoder(nn.Module):
                 - attentions
         """
         # embed positions
-        positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
+        positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)

         if self.generation_mode:
             input_ids = input_ids[:, -1:]
@@ -475,7 +475,7 @@ class BartDecoder(nn.Module):
                 continue

             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
-            x, layer_self_attn, layer_past = decoder_layer.forward(
+            x, layer_self_attn, layer_past = decoder_layer(
                 x,
                 encoder_hidden_states,
                 encoder_padding_mask,
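
These two BartDecoder hunks cover the generation path: position embeddings come from calling self.embed_positions(...), only the newest token is fed through the stack when generation_mode is set, and each layer picks up its own entry from decoder_cached_states. An illustrative sketch of that bookkeeping with stand-in layers:

# Illustrative generation-mode bookkeeping with stand-in layers (not the BART code):
# with a populated cache, only the newest token is re-encoded, and each layer
# receives its own cached state.
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])  # stand-ins for DecoderLayer
decoder_cached_states = [{} for _ in layers]                   # one state dict per layer
generation_mode = True

input_ids = torch.tensor([[0, 5, 7, 9]])  # (batch, seq) tokens generated so far
if generation_mode:
    input_ids = input_ids[:, -1:]         # keep only the last token

x = torch.randn(input_ids.shape[1], input_ids.shape[0], 16)  # (seq, batch, embed)
for i, decoder_layer in enumerate(layers):
    layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
    x = decoder_layer(x)  # the real call also passes layer_state, masks, etc.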
@@ -836,10 +836,10 @@ class BartModel(PretrainedBartModel):
             )
         assert decoder_input_ids is not None
         if encoder_outputs is None:
-            encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
+            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         assert isinstance(encoder_outputs, tuple)
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        decoder_outputs = self.decoder.forward(
+        decoder_outputs = self.decoder(
             decoder_input_ids,
             encoder_outputs[0],
             attention_mask,
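
The BartModel hunk shows the usual seq2seq wiring: the encoder runs only when encoder_outputs is not supplied, so a caller can encode once and reuse the result across decoding steps. A hedged sketch of that wrapper pattern with dummy submodules (not the real BartEncoder/BartDecoder):

# Minimal seq2seq wrapper sketch with dummy submodules (not the real BART classes):
# encode once, then reuse encoder_outputs for every later call.
import torch
import torch.nn as nn

class TinySeq2Seq(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)  # stand-in for BartEncoder
        self.decoder = nn.Linear(dim, dim)  # stand-in for BartDecoder

    def forward(self, inputs, decoder_inputs, encoder_outputs=None):
        if encoder_outputs is None:
            encoder_outputs = (self.encoder(inputs),)  # tuple, like the real model
        assert isinstance(encoder_outputs, tuple)
        decoder_outputs = self.decoder(decoder_inputs + encoder_outputs[0])
        return (decoder_outputs,) + encoder_outputs

model = TinySeq2Seq()
src, tgt = torch.randn(2, 16), torch.randn(2, 16)
enc = model.encoder(src)                        # precompute once...
out = model(src, tgt, encoder_outputs=(enc,))   # ...then reuse on later calls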
@@ -925,7 +925,7 @@ class BartForMaskedLM(PretrainedBartModel):
             outputs = model(input_ids=input_ids, lm_labels=input_ids)
             loss, prediction_scores = outputs[:2]
         """
-        outputs = self.model.forward(
+        outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -933,7 +933,7 @@ class BartForMaskedLM(PretrainedBartModel):
             decoder_attention_mask=decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
         )
-        lm_logits = self.lm_head.forward(outputs[0])
+        lm_logits = self.lm_head(outputs[0])
         outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
         if lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss()
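
After the backbone call, the LM head is a plain projection over the decoder features and the masked-LM loss is token-level cross entropy. A short sketch of that tail end (lm_head and the shapes here are illustrative stand-ins):

# Sketch of the LM-head projection plus cross-entropy loss (shapes are made up).
import torch
import torch.nn as nn

vocab_size, hidden = 50265, 16
lm_head = nn.Linear(hidden, vocab_size, bias=False)  # illustrative stand-in

dec_features = torch.randn(2, 7, hidden)         # (batch, seq, hidden) from the decoder
lm_labels = torch.randint(vocab_size, (2, 7))    # target token ids

lm_logits = lm_head(dec_features)                # call the module, not lm_head.forward
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(lm_logits.view(-1, vocab_size), lm_labels.view(-1))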
@@ -1308,7 +1308,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             loss, logits = outputs[:2]

         """
-        outputs = self.model.forward(
+        outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,