diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index ae53e54ee8f..41150e4c178 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -559,7 +559,7 @@ class ProphetNetPreTrainedModel(PreTrainedModel):
         return shifted_input_ids
 
 
-class ProhpetNetPositionalEmbeddings(nn.Embedding):
+class ProphetNetPositionalEmbeddings(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
     based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
@@ -598,7 +598,7 @@ class ProhpetNetPositionalEmbeddings(nn.Embedding):
         return super().forward(position_ids)
 
 
-class ProphetNetSelfAttention(nn.Module):
+class ProphetNetAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -726,7 +726,7 @@ class ProphetNetSelfAttention(nn.Module):
         return attn_output, attn_weights_reshaped
 
 
-class ProhpetNetFeedForward(nn.Module):
+class ProphetNetFeedForward(nn.Module):
     """
     This is the residual two feed-forward layer block based on the original Transformer implementation.
     """
@@ -749,14 +749,14 @@ class ProhpetNetFeedForward(nn.Module):
         return hidden_states
 
 
-class ProphetNetNgramProphetNetSelfAttention(nn.Module):
+class ProphetNetNgramSelfAttention(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
 
         self.num_buckets = config.num_buckets
         self.relative_max_distance = config.relative_max_distance
-        self.num_attn_heads = config.num_attention_heads
+        self.num_attn_heads = config.num_decoder_attention_heads
         self.dropout = config.dropout
         self.attention_dropout = config.attention_dropout
         self.head_dim = config.hidden_size // self.num_attn_heads
@@ -1046,11 +1046,11 @@ class ProphetNetEncoderLayer(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
+        self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
-        self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
+        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
 
     def forward(self, hidden_states, attention_mask):
@@ -1075,16 +1075,16 @@ class ProphetNetDecoderLayer(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
+        self.self_attn = ProphetNetNgramSelfAttention(config)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
         if config.add_cross_attention:
-            self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
+            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
             self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 3rd residual block
-        self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
+        self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
 
     def forward(
@@ -1156,7 +1156,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
             if word_embeddings is not None
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
-        self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
         self.embeddings_layer_norm = LayerNorm(config.hidden_size)
 
         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
@@ -1212,7 +1212,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         # prepare attention mask
         if attention_mask is not None:
             extended_attention_mask = (
-                1.0 - attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
+                1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
             ) * -10000.0
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1273,7 +1273,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             if word_embeddings is not None
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
-        self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
 
         self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
@@ -1397,7 +1397,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         # prepare encoder attention mask
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
-                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
+                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
             ) * -10000.0
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else:
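
Besides the class renames, the functional part of the patch is switching the attention-mask broadcast from a generic `num_attention_heads` to the encoder- or decoder-specific head count, since ProphetNet configures `num_encoder_attention_heads` and `num_decoder_attention_heads` separately. Below is a minimal sketch (not part of the patch) of the shape arithmetic the encoder path relies on; the batch size, sequence length, and head count are arbitrary illustration values.

```python
import torch

# Illustration values, not taken from any real config.
batch_size, seq_len, num_encoder_attention_heads = 2, 5, 16

attention_mask = torch.ones(batch_size, seq_len)  # 1 = attend, 0 = masked

# Same pattern as the patched encoder code: insert a broadcast dim, tile by the
# encoder-specific head count, and turn masked positions into large negatives.
extended_attention_mask = (
    1.0 - attention_mask[:, None, :].repeat(num_encoder_attention_heads, 1, 1)
) * -10000.0

# The attention scores are laid out as (batch_size * num_heads, tgt_len, src_len),
# so the mask's leading dim must be batch_size * num_encoder_attention_heads;
# using a head count from the other side of the model would break this.
assert extended_attention_mask.shape == (
    batch_size * num_encoder_attention_heads, 1, seq_len
)
```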