Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
parent cbfcfce205
commit 23edebc079
update extract_distilbert
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 Preprocessing script before training DistilBERT.
+Specific to BERT -> DistilBERT.
 """
 from transformers import BertForMaskedLM, RobertaForMaskedLM
 import torch
@@ -21,7 +22,7 @@ import argparse
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
-    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
+    parser.add_argument("--model_type", default="bert", choices=["bert"])
     parser.add_argument("--model_name", default='bert-base-uncased', type=str)
     parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
     parser.add_argument("--vocab_transform", action='store_true')
@@ -31,9 +32,8 @@ if __name__ == '__main__':
     if args.model_type == 'bert':
         model = BertForMaskedLM.from_pretrained(args.model_name)
         prefix = 'bert'
-    elif args.model_type == 'roberta':
-        model = RobertaForMaskedLM.from_pretrained(args.model_name)
-        prefix = 'roberta'
+    else:
+        raise ValueError(f'args.model_type should be "bert".')
 
     state_dict = model.state_dict()
     compressed_sd = {}
@@ -68,20 +68,12 @@ if __name__ == '__main__':
                 state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1
 
-    if args.model_type == 'bert':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
-    elif args.model_type == 'roberta':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
+    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
+    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
+    if args.vocab_transform:
+        for w in ['weight', 'bias']:
+            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
+            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
 
     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
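For context, a minimal sketch of the extraction flow after this change, assuming the parts of the script not shown in these hunks stay as before (the teacher-layer copy loop and a final torch.save of compressed_sd to --dump_checkpoint are assumed here, not confirmed by the diff); only the BERT path remains, and any other --model_type now raises a ValueError:

# Minimal sketch of the post-change code path (not the full script).
# Assumption: the script ends by saving compressed_sd to args.dump_checkpoint;
# that step is not part of the hunks above.
import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-uncased')  # --model_name default
state_dict = model.state_dict()
compressed_sd = {}

# BERT MLM head -> DistilBERT vocabulary projector, as in the new branch above.
compressed_sd['vocab_projector.weight'] = state_dict['cls.predictions.decoder.weight']
compressed_sd['vocab_projector.bias'] = state_dict['cls.predictions.bias']

# Optional vocab transform / layer norm, gated by --vocab_transform in the script.
for w in ['weight', 'bias']:
    compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
    compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']

torch.save(compressed_sd, 'serialization_dir/tf_bert-base-uncased_0247911.pth')  # --dump_checkpoint default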