Conditional append/init + fixed warning

This commit is contained in:
LysandreJik 2019-08-21 21:37:30 -04:00
parent 5c2b94c82a
commit c85b5db61a

View File

@ -379,11 +379,15 @@ class PreTrainedModel(nn.Module):
for head in heads: for head in heads:
if head not in self.config.pruned_heads[int(layer)]: if head not in self.config.pruned_heads[int(layer)]:
self.config.pruned_heads[int(layer)].append(head) self.config.pruned_heads[int(layer)].append(head)
if int(layer) in to_be_pruned:
to_be_pruned[int(layer)].append(head) to_be_pruned[int(layer)].append(head)
else: else:
logger.warning("Tried to remove head " + head + to_be_pruned[int(layer)] = [head]
" of layer " + layer + else:
" but it was already removed. The current removed heads are " + heads_to_prune) logger.warning("Tried to remove head " + str(head) +
" of layer " + str(layer) +
" but it was already removed. The current removed heads are " + str(heads_to_prune))
base_model._prune_heads(to_be_pruned) base_model._prune_heads(to_be_pruned)