Conditional append/init + fixed warning

This commit is contained in:
LysandreJik 2019-08-21 21:37:30 -04:00
parent 5c2b94c82a
commit c85b5db61a

View File

@ -379,11 +379,15 @@ class PreTrainedModel(nn.Module):
for head in heads:
if head not in self.config.pruned_heads[int(layer)]:
self.config.pruned_heads[int(layer)].append(head)
to_be_pruned[int(layer)].append(head)
if int(layer) in to_be_pruned:
to_be_pruned[int(layer)].append(head)
else:
to_be_pruned[int(layer)] = [head]
else:
logger.warning("Tried to remove head " + head +
" of layer " + layer +
" but it was already removed. The current removed heads are " + heads_to_prune)
logger.warning("Tried to remove head " + str(head) +
" of layer " + str(layer) +
" but it was already removed. The current removed heads are " + str(heads_to_prune))
base_model._prune_heads(to_be_pruned)