diff --git a/modeling.py b/modeling.py
index 43db3b30fb3..9c6fa38e051 100644
--- a/modeling.py
+++ b/modeling.py
@@ -337,8 +337,8 @@ class BertModel(nn.Module):
             token_type_ids = torch.zeros_like(input_ids)
 
         # We create a 3D attention mask from a 2D tensor mask.
-        # Sizes are [batch_size, 1, 1, from_seq_length]
-        # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
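
For context, here is a minimal sketch of the broadcast the corrected comment describes. The sizes (`batch_size`, `num_heads`, `seq_length`) are hypothetical values chosen for illustration, not taken from the diff; the additive `-10000.0` masking at the end mirrors the pattern BERT-style models typically apply to the extended mask downstream.

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, num_heads, seq_length = 2, 12, 5

# 2D padding mask: [batch_size, to_seq_length]
# (1 = attend to this position, 0 = padding)
attention_mask = torch.ones(batch_size, seq_length)

# Two unsqueezes give [batch_size, 1, 1, to_seq_length],
# as the corrected comment says.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
print(extended_attention_mask.shape)  # torch.Size([2, 1, 1, 5])

# Attention scores have shape
# [batch_size, num_heads, from_seq_length, to_seq_length];
# the two singleton dims broadcast over num_heads and from_seq_length.
scores = torch.randn(batch_size, num_heads, seq_length, seq_length)
masked_scores = scores + (1.0 - extended_attention_mask) * -10000.0
print(masked_scores.shape)  # torch.Size([2, 12, 5, 5])
```

The key point of the fix: the last dimension of the mask is the *to* (key) axis, so the mask must be `[batch_size, 1, 1, to_seq_length]` to line up with the scores' `[..., from_seq_length, to_seq_length]`, with broadcasting filling in the head and query (*from*) dimensions.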