Re-ordering of group_idx/layer_idx + Python 2 tests

This commit is contained in:
Lysandre 2019-11-07 19:55:43 +00:00 committed by Lysandre Debut
parent 9d5c49546f
commit d9daad98c7

View File

@@ -281,11 +281,17 @@ class AlbertTransformer(nn.Module):
if self.output_hidden_states:
all_hidden_states = (hidden_states,)
for layer_idx in range(self.config.num_hidden_layers):
group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups)
for i in range(self.config.num_hidden_layers):
# Number of layers in a hidden group
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
# Index of the hidden group
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
# Index of the layer inside the group
layer_idx = int(i - group_idx * layers_per_group)
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
hidden_states = layer_group_output[0]
if self.output_attentions: