diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 08bf9d82d0d..b33691d6462 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -450,6 +450,8 @@ class AlbertTransformer(nn.Module): all_hidden_states = (hidden_states,) if output_hidden_states else None all_attentions = () if output_attentions else None + head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask + for i in range(self.config.num_hidden_layers): # Number of layers in a hidden group layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)