update extract_distilbert

VictorSanh 2019-10-02 11:01:33 -04:00 committed by Victor SANH
parent cbfcfce205
commit 23edebc079

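Summary: the commit narrows this preprocessing script to BERT teachers only. The --model_type argument loses its "roberta" choice, the RobertaForMaskedLM loading branch is replaced by a ValueError, and the RoBERTa-specific vocab-head mapping (lm_head.*) is dropped so the BERT path (cls.predictions.*) runs unconditionally.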

@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 Preprocessing script before training DistilBERT.
+Specific to BERT -> DistilBERT.
 """
 from transformers import BertForMaskedLM, RobertaForMaskedLM
 import torch
@@ -21,7 +22,7 @@ import argparse
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
-    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
+    parser.add_argument("--model_type", default="bert", choices=["bert"])
     parser.add_argument("--model_name", default='bert-base-uncased', type=str)
     parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
     parser.add_argument("--vocab_transform", action='store_true')
@@ -31,9 +32,8 @@ if __name__ == '__main__':
     if args.model_type == 'bert':
         model = BertForMaskedLM.from_pretrained(args.model_name)
         prefix = 'bert'
-    elif args.model_type == 'roberta':
-        model = RobertaForMaskedLM.from_pretrained(args.model_name)
-        prefix = 'roberta'
+    else:
+        raise ValueError(f'args.model_type should be "bert".')

     state_dict = model.state_dict()
     compressed_sd = {}
@@ -68,20 +68,12 @@ if __name__ == '__main__':
                 state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1

-    if args.model_type == 'bert':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
-    elif args.model_type == 'roberta':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
+    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
+    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
+    if args.vocab_transform:
+        for w in ['weight', 'bias']:
+            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
+            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']

     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
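For context, a minimal sketch of the BERT-only extraction flow this commit leaves behind. The vocab-head key mapping is taken verbatim from the diff; the student layer-key layout (distilbert.transformer.layer.{i}.*) and the teacher layer selection [0, 2, 4, 7, 9, 11] are illustrative assumptions, since the diff only shows the teacher-side key and the std_idx counter.

import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
state_dict = model.state_dict()
compressed_sd = {}

# Copy a subset of teacher encoder layers into consecutive student slots.
# The indices and the student key prefix below are assumptions for
# illustration, not part of this commit's diff.
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
    for w in ['weight', 'bias']:
        compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
            state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
    std_idx += 1

# Vocabulary head, as in the post-commit code (BERT keys only).
compressed_sd['vocab_projector.weight'] = state_dict['cls.predictions.decoder.weight']
compressed_sd['vocab_projector.bias'] = state_dict['cls.predictions.bias']
for w in ['weight', 'bias']:
    compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
    compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']

print(f'N layers selected for distillation: {std_idx}')
torch.save(compressed_sd, 'extracted_bert.pth')

After this change, an invocation would look something like python extract_distilbert.py --model_type bert --model_name bert-base-uncased --dump_checkpoint serialization_dir/tf_bert-base-uncased_0247911.pth --vocab_transform (the script path is assumed from the commit title; the flags and defaults are the ones shown in the diff).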