transformers/examples/distillation/scripts/extract_for_distil.py
2019-08-28 04:00:19 +00:00

60 lines
3.2 KiB
Python

from pytorch_transformers import BertForPreTraining
import torch
import argparse
# Teacher encoder layers copied into the student by default, in student order:
# student layer i is initialised from teacher layer DEFAULT_TEACHER_LAYERS[i].
DEFAULT_TEACHER_LAYERS = (0, 2, 4, 7, 9, 11)

# (student sub-module, teacher sub-module) name pairs copied for every selected
# encoder layer; each pair contributes both a `weight` and a `bias` entry.
_LAYER_PARAM_MAP = (
    ('attention.q_lin', 'attention.self.query'),
    ('attention.k_lin', 'attention.self.key'),
    ('attention.v_lin', 'attention.self.value'),
    ('attention.out_lin', 'attention.output.dense'),
    ('sa_layer_norm', 'attention.output.LayerNorm'),
    ('ffn.lin1', 'intermediate.dense'),
    ('ffn.lin2', 'output.dense'),
    ('output_layer_norm', 'output.LayerNorm'),
)


def extract_student_state_dict(teacher_sd, teacher_layers=DEFAULT_TEACHER_LAYERS, vocab_transform=False):
    """Build the student (dilbert) state dict from a BertForPreTraining state dict.

    Args:
        teacher_sd: mapping from teacher parameter names (``bert.*`` / ``cls.*``)
            to values, e.g. the result of ``BertForPreTraining.state_dict()``.
        teacher_layers: indices of the teacher encoder layers to keep; the
            k-th entry (0-based) initialises student layer ``k``.
        vocab_transform: if True, also copy the MLM head's transform dense
            layer and LayerNorm (``cls.predictions.transform.*``).

    Returns:
        A new dict mapping student parameter names (``dilbert.*``,
        ``vocab_projector.*`` and, optionally, ``vocab_transform.*`` /
        ``vocab_layer_norm.*``) to the corresponding teacher values. Values are
        referenced as-is, not copied. Teacher entries not listed here (for
        instance ``bert.embeddings.token_type_embeddings``) are dropped.

    Raises:
        KeyError: if ``teacher_sd`` lacks an expected key, e.g. when a requested
            layer index does not exist in the teacher.
    """
    student_sd = {}

    # Embeddings: token and position tables plus their LayerNorm.
    for name in ('word_embeddings', 'position_embeddings'):
        student_sd[f'dilbert.embeddings.{name}.weight'] = teacher_sd[f'bert.embeddings.{name}.weight']
    for w in ('weight', 'bias'):
        student_sd[f'dilbert.embeddings.LayerNorm.{w}'] = teacher_sd[f'bert.embeddings.LayerNorm.{w}']

    # Encoder: student layers are numbered consecutively from 0 regardless of
    # which teacher layers were picked.
    for student_idx, teacher_idx in enumerate(teacher_layers):
        student_prefix = f'dilbert.transformer.layer.{student_idx}'
        teacher_prefix = f'bert.encoder.layer.{teacher_idx}'
        for w in ('weight', 'bias'):
            for student_name, teacher_name in _LAYER_PARAM_MAP:
                student_sd[f'{student_prefix}.{student_name}.{w}'] = \
                    teacher_sd[f'{teacher_prefix}.{teacher_name}.{w}']

    # Output projection onto the vocabulary (MLM decoder).
    student_sd['vocab_projector.weight'] = teacher_sd['cls.predictions.decoder.weight']
    student_sd['vocab_projector.bias'] = teacher_sd['cls.predictions.bias']
    if vocab_transform:
        for w in ('weight', 'bias'):
            student_sd[f'vocab_transform.{w}'] = teacher_sd[f'cls.predictions.transform.dense.{w}']
            student_sd[f'vocab_layer_norm.{w}'] = teacher_sd[f'cls.predictions.transform.LayerNorm.{w}']
    return student_sd


def main():
    """Parse CLI arguments, extract the student weights and save them with ``torch.save``."""
    parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForPreTraining for Transfer Learned Distillation")
    parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
    parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str)
    parser.add_argument("--vocab_transform", action='store_true')
    parser.add_argument("--teacher_layers", type=int, nargs='+', default=list(DEFAULT_TEACHER_LAYERS),
                        help="Teacher encoder layer indices copied into the student, in student order.")
    args = parser.parse_args()

    model = BertForPreTraining.from_pretrained(args.bert_model)
    compressed_sd = extract_student_state_dict(model.state_dict(), args.teacher_layers, args.vocab_transform)

    print(f'N layers selected for distillation: {len(args.teacher_layers)}')
    print(f'Number of params transfered for distillation: {len(compressed_sd)}')
    print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
    torch.save(compressed_sd, args.dump_checkpoint)


if __name__ == '__main__':
    main()