fix head_mask for albert encoder part (AlbertTransformer) (#11596)

* fix head mask for albert encoder part

* fix head_mask for albert encoder part
commit c1780ce7a4 (parent 864c1dfe34)
Author: baeseongsu
Date: 2021-05-06 15:18:02 +09:00 (committed by GitHub)


@@ -450,6 +450,8 @@ class AlbertTransformer(nn.Module):
         all_hidden_states = (hidden_states,) if output_hidden_states else None
         all_attentions = () if output_attentions else None
 
+        head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
+
         for i in range(self.config.num_hidden_layers):
             # Number of layers in a hidden group
             layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
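Why the added default matters: later in this same loop, AlbertTransformer slices the mask per layer group (head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group]). AlbertModel.forward normalizes the mask via get_head_mask() before calling the encoder, but invoking AlbertTransformer directly with head_mask=None used to hit that slice and raise TypeError: 'NoneType' object is not subscriptable. Below is a minimal standalone sketch of the behavior; num_hidden_layers and num_hidden_groups mirror the config fields used above, and everything else is illustrative rather than the real ALBERT modules.

    num_hidden_layers = 12
    num_hidden_groups = 1
    layers_per_group = num_hidden_layers // num_hidden_groups

    head_mask = None  # caller passed no per-head mask

    # Before the fix, the per-group slice failed on None:
    #   head_mask[0:layers_per_group]  ->  TypeError: 'NoneType' object is not subscriptable

    # The fixed code normalizes the mask first, so each slice is a list of Nones,
    # which the attention layers already interpret as "do not mask any head".
    head_mask = [None] * num_hidden_layers if head_mask is None else head_mask
    for group_idx in range(num_hidden_groups):
        group_mask = head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group]
        assert group_mask == [None] * layers_per_group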