mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Re-ordering of group_idx/layer_idx + Python 2 tests
This commit is contained in:
parent
9d5c49546f
commit
d9daad98c7
@@ -281,11 +281,17 @@ class AlbertTransformer(nn.Module):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = (hidden_states,)
|
||||
|
||||
for layer_idx in range(self.config.num_hidden_layers):
|
||||
group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups)
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
# Number of layers in a hidden group
|
||||
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
||||
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
|
||||
|
||||
# Index of the hidden group
|
||||
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
||||
|
||||
# Index of the layer inside the group
|
||||
layer_idx = int(i - group_idx * layers_per_group)
|
||||
|
||||
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
|
||||
hidden_states = layer_group_output[0]
|
||||
|
||||
if self.output_attentions:
|
||||
|
Loading…
Reference in New Issue
Block a user